Merge branch 'develop' into push_badge_counts

This commit is contained in:
David Baker 2016-01-19 18:17:23 +00:00
commit afb7b377f2
35 changed files with 1043 additions and 1246 deletions

View file

@ -258,6 +258,14 @@ During setup of Synapse you need to call python2.7 directly again::
...substituting your host and domain name as appropriate.
FreeBSD
-------
Synapse can be installed via FreeBSD Ports or Packages:
- Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean``
- Packages: ``pkg install py27-matrix-synapse``
Windows Install
---------------
Synapse can be installed on Cygwin. It requires the following Cygwin packages:

View file

@ -510,37 +510,14 @@ class Auth(object):
"""
# Can optionally look elsewhere in the request (e.g. headers)
try:
access_token = request.args["access_token"][0]
# Check for application service tokens with a user_id override
try:
app_service = yield self.store.get_app_service_by_token(
access_token
)
if not app_service:
raise KeyError
user_id = app_service.sender
if "user_id" in request.args:
user_id = request.args["user_id"][0]
if not app_service.is_interested_in_user(user_id):
raise AuthError(
403,
"Application service cannot masquerade as this user."
)
if not user_id:
raise KeyError
user_id = yield self._get_appservice_user_id(request.args)
if user_id:
request.authenticated_entity = user_id
defer.returnValue(
Requester(UserID.from_string(user_id), "", False)
)
return
except KeyError:
pass # normal users won't have the user_id query parameter set.
access_token = request.args["access_token"][0]
user_info = yield self._get_user_by_access_token(access_token)
user = user_info["user"]
token_id = user_info["token_id"]
@ -573,6 +550,33 @@ class Auth(object):
errcode=Codes.MISSING_TOKEN
)
@defer.inlineCallbacks
def _get_appservice_user_id(self, request_args):
app_service = yield self.store.get_app_service_by_token(
request_args["access_token"][0]
)
if app_service is None:
defer.returnValue(None)
if "user_id" not in request_args:
defer.returnValue(app_service.sender)
user_id = request_args["user_id"][0]
if app_service.sender == user_id:
defer.returnValue(app_service.sender)
if not app_service.is_interested_in_user(user_id):
raise AuthError(
403,
"Application service cannot masquerade as this user."
)
if not (yield self.store.get_user_by_id(user_id)):
raise AuthError(
403,
"Application service has not registered this user"
)
defer.returnValue(user_id)
@defer.inlineCallbacks
def _get_user_by_access_token(self, token):
""" Get a registered user's ID.

View file

@ -29,6 +29,7 @@ class Codes(object):
USER_IN_USE = "M_USER_IN_USE"
ROOM_IN_USE = "M_ROOM_IN_USE"
BAD_PAGINATION = "M_BAD_PAGINATION"
BAD_STATE = "M_BAD_STATE"
UNKNOWN = "M_UNKNOWN"
NOT_FOUND = "M_NOT_FOUND"
MISSING_TOKEN = "M_MISSING_TOKEN"
@ -42,6 +43,7 @@ class Codes(object):
EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "THREEPID_IN_USE"
INVALID_USERNAME = "M_INVALID_USERNAME"
class CodeMessageException(RuntimeError):

View file

@ -88,6 +88,9 @@ import time
logger = logging.getLogger("synapse.app.homeserver")
ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
@ -495,9 +498,8 @@ class SynapseRequest(Request):
)
def get_redacted_uri(self):
return re.sub(
r'(\?.*access_token=)[^&]*(.*)$',
r'\1<redacted>\2',
return ACCESS_TOKEN_RE.sub(
r'\1<redacted>\3',
self.uri
)

View file

@ -117,6 +117,15 @@ class EventBase(object):
def __set__(self, instance, value):
raise AttributeError("Unrecognized attribute %s" % (instance,))
def __getitem__(self, field):
return self._event_dict[field]
def __contains__(self, field):
return field in self._event_dict
def items(self):
return self._event_dict.items()
class FrozenEvent(EventBase):
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):

View file

@ -53,16 +53,54 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_guest=False):
# Assumes that user has at some point joined the room if not is_guest.
def _filter_events_for_clients(self, users, events):
""" Returns dict of user_id -> list of events that user is allowed to
see.
"""
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, None),
)
)
forgotten = yield defer.gatherResults([
self.store.who_forgot_in_room(
room_id,
)
for room_id in frozenset(e.room_id for e in events)
], consumeErrors=True)
# Set of membership event_ids that have been forgotten
event_id_forgotten = frozenset(
row["event_id"] for rows in forgotten for row in rows
)
def allowed(event, user_id, is_guest):
state = event_id_to_state[event.event_id]
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared")
else:
visibility = "shared"
def allowed(event, membership, visibility):
if visibility == "world_readable":
return True
if is_guest:
return False
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
if membership_event.event_id in event_id_forgotten:
membership = None
else:
membership = membership_event.membership
else:
membership = None
if membership == Membership.JOIN:
return True
@ -78,43 +116,20 @@ class BaseHandler(object):
return True
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
)
defer.returnValue({
user_id: [
event
for event in events
if allowed(event, user_id, is_guest)
]
for user_id, is_guest in users
})
events_to_return = []
for event in events:
state = event_id_to_state[event.event_id]
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
was_forgotten_at_event = yield self.store.was_forgotten_at(
membership_event.state_key,
membership_event.room_id,
membership_event.event_id
)
if was_forgotten_at_event:
membership = None
else:
membership = membership_event.membership
else:
membership = None
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared")
else:
visibility = "shared"
should_include = allowed(event, membership, visibility)
if should_include:
events_to_return.append(event)
defer.returnValue(events_to_return)
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_guest=False):
# Assumes that user has at some point joined the room if not is_guest.
res = yield self._filter_events_for_clients([(user_id, is_guest)], events)
defer.returnValue(res.get(user_id, []))
def ratelimit(self, user_id):
time_now = self.clock.time()
@ -171,12 +186,10 @@ class BaseHandler(object):
)
@defer.inlineCallbacks
def handle_new_client_event(self, event, context, extra_destinations=[],
extra_users=[], suppress_auth=False):
def handle_new_client_event(self, event, context, extra_users=[]):
# We now need to go and hit out to wherever we need to hit out to.
if not suppress_auth:
self.auth.check(event, auth_events=context.current_state)
self.auth.check(event, auth_events=context.current_state)
yield self.maybe_kick_guest_users(event, context.current_state.values())
@ -253,12 +266,12 @@ class BaseHandler(object):
event, context=context
)
action_generator = ActionGenerator(self.store)
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
event, self
)
destinations = set(extra_destinations)
destinations = set()
for k, s in context.current_state.items():
try:
if k[0] == EventTypes.Member:

View file

@ -245,7 +245,7 @@ class FederationHandler(BaseHandler):
yield user_joined_room(self.distributor, user, event.room_id)
if not backfilled and not event.internal_metadata.is_outlier():
action_generator = ActionGenerator(self.store)
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
event, self
)
@ -1692,7 +1692,7 @@ class FederationHandler(BaseHandler):
self.auth.check(event, context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context)
yield member_handler.send_membership_event(event, context)
else:
destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)])
yield self.replication_layer.forward_third_party_invite(
@ -1721,7 +1721,7 @@ class FederationHandler(BaseHandler):
# TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context)
yield member_handler.send_membership_event(event, context)
@defer.inlineCallbacks
def add_display_name_to_third_party_invite(self, event_dict, event, context):

View file

@ -174,30 +174,25 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True,
token_id=None, txn_id=None, is_guest=False):
""" Given a dict from a client, create and handle a new event.
def create_event(self, event_dict, token_id=None, txn_id=None):
"""
Given a dict from a client, create a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
etc.
Adds display names to Join membership events.
Persists and notifies local clients and federation.
Args:
event_dict (dict): An entire event
Returns:
Tuple of created event (FrozenEvent), Context
"""
builder = self.event_builder_factory.new(event_dict)
self.validator.validate_new(builder)
if ratelimit:
self.ratelimit(builder.user_id)
# TODO(paul): Why does 'event' not have a 'user' object?
user = UserID.from_string(builder.user_id)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
if membership == Membership.JOIN:
@ -216,6 +211,25 @@ class MessageHandler(BaseHandler):
event, context = yield self._create_new_client_event(
builder=builder,
)
defer.returnValue((event, context))
@defer.inlineCallbacks
def send_event(self, event, context, ratelimit=True, is_guest=False):
"""
Persists and notifies local clients and federation of an event.
Args:
event (FrozenEvent) the event to send.
context (Context) the context of the event.
ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest.
"""
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if ratelimit:
self.ratelimit(event.sender)
if event.is_state():
prev_state = context.current_state.get((event.type, event.state_key))
@ -229,7 +243,7 @@ class MessageHandler(BaseHandler):
if event.type == EventTypes.Member:
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context, is_guest=is_guest)
yield member_handler.send_membership_event(event, context, is_guest=is_guest)
else:
yield self.handle_new_client_event(
event=event,
@ -241,6 +255,25 @@ class MessageHandler(BaseHandler):
with PreserveLoggingContext():
presence.bump_presence_active_time(user)
@defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True,
token_id=None, txn_id=None, is_guest=False):
"""
Creates an event, then sends it.
See self.create_event and self.send_event.
"""
event, context = yield self.create_event(
event_dict,
token_id=token_id,
txn_id=txn_id
)
yield self.send_event(
event,
context,
ratelimit=ratelimit,
is_guest=is_guest
)
defer.returnValue(event)
@defer.inlineCallbacks

View file

@ -53,7 +53,8 @@ class RegistrationHandler(BaseHandler):
raise SynapseError(
400,
"User ID must only contain characters which do not"
" require URL encoding."
" require URL encoding.",
Codes.INVALID_USERNAME
)
user = UserID(localpart, self.hs.hostname)

View file

@ -22,7 +22,7 @@ from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset,
)
from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.api.errors import AuthError, StoreError, SynapseError, Codes
from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor
@ -397,7 +397,58 @@ class RoomMemberHandler(BaseHandler):
remotedomains.add(member.domain)
@defer.inlineCallbacks
def change_membership(self, event, context, do_auth=True, is_guest=False):
def update_membership(self, requester, target, room_id, action, txn_id=None):
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
elif action == "forget":
effective_membership_state = "leave"
msg_handler = self.hs.get_handlers().message_handler
content = {"membership": unicode(effective_membership_state)}
if requester.is_guest:
content["kind"] = "guest"
event, context = yield msg_handler.create_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": target.to_string(),
},
token_id=requester.access_token_id,
txn_id=txn_id,
)
old_state = context.current_state.get((EventTypes.Member, event.state_key))
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
403,
"Cannot unban user who was not banned (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE
)
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was is banned" % (action,),
errcode=Codes.BAD_STATE
)
yield msg_handler.send_event(
event,
context,
ratelimit=True,
is_guest=requester.is_guest
)
if action == "forget":
yield self.forget(requester.user, room_id)
@defer.inlineCallbacks
def send_membership_event(self, event, context, is_guest=False):
""" Change the membership status of a user in a room.
Args:
@ -432,7 +483,7 @@ class RoomMemberHandler(BaseHandler):
if not is_guest_access_allowed:
raise AuthError(403, "Guest access not allowed")
yield self._do_join(event, context, do_auth=do_auth)
yield self._do_join(event, context)
else:
if event.membership == Membership.LEAVE:
is_host_in_room = yield self.is_host_in_room(room_id, context)
@ -459,9 +510,7 @@ class RoomMemberHandler(BaseHandler):
yield self._do_local_membership_update(
event,
membership=event.content["membership"],
context=context,
do_auth=do_auth,
)
if prev_state and prev_state.membership == Membership.JOIN:
@ -497,12 +546,12 @@ class RoomMemberHandler(BaseHandler):
})
event, context = yield self._create_new_client_event(builder)
yield self._do_join(event, context, room_hosts=hosts, do_auth=True)
yield self._do_join(event, context, room_hosts=hosts)
defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks
def _do_join(self, event, context, room_hosts=None, do_auth=True):
def _do_join(self, event, context, room_hosts=None):
room_id = event.room_id
# XXX: We don't do an auth check if we are doing an invite
@ -536,9 +585,7 @@ class RoomMemberHandler(BaseHandler):
yield self._do_local_membership_update(
event,
membership=event.content["membership"],
context=context,
do_auth=do_auth,
)
prev_state = context.current_state.get((event.type, event.state_key))
@ -603,8 +650,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(room_ids)
@defer.inlineCallbacks
def _do_local_membership_update(self, event, membership, context,
do_auth):
def _do_local_membership_update(self, event, context):
yield run_on_reactor()
target_user = UserID.from_string(event.state_key)
@ -613,7 +659,6 @@ class RoomMemberHandler(BaseHandler):
event,
context,
extra_users=[target_user],
suppress_auth=(not do_auth),
)
@defer.inlineCallbacks
@ -880,28 +925,39 @@ class RoomContextHandler(BaseHandler):
(excluding state).
Returns:
dict
dict, or None if the event isn't found
"""
before_limit = math.floor(limit/2.)
after_limit = limit - before_limit
now_token = yield self.hs.get_event_sources().get_current_token()
def filter_evts(events):
return self._filter_events_for_client(
user.to_string(),
events,
is_guest=is_guest)
event = yield self.store.get_event(event_id, get_prev_content=True,
allow_none=True)
if not event:
defer.returnValue(None)
return
filtered = yield(filter_evts([event]))
if not filtered:
raise AuthError(
403,
"You don't have permission to access that event."
)
results = yield self.store.get_events_around(
room_id, event_id, before_limit, after_limit
)
results["events_before"] = yield self._filter_events_for_client(
user.to_string(),
results["events_before"],
is_guest=is_guest,
)
results["events_after"] = yield self._filter_events_for_client(
user.to_string(),
results["events_after"],
is_guest=is_guest,
)
results["events_before"] = yield filter_evts(results["events_before"])
results["events_after"] = yield filter_evts(results["events_after"])
results["event"] = event
if results["events_after"]:
last_event_id = results["events_after"][-1].event_id

View file

@ -55,6 +55,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"ephemeral",
"account_data",
"unread_notification_count",
"unread_highlight_count",
])):
__slots__ = []
@ -292,9 +293,14 @@ class SyncHandler(BaseHandler):
notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, ephemeral_by_room
)
notif_count = None
highlight_count = None
if notifs is not None:
notif_count = len(notifs)
highlight_count = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
current_state = yield self.get_state_at(room_id, now_token)
@ -307,6 +313,7 @@ class SyncHandler(BaseHandler):
room_id, tags_by_room, account_data_by_room
),
unread_notification_count=notif_count,
unread_highlight_count=highlight_count,
))
def account_data_for_user(self, account_data):
@ -529,9 +536,14 @@ class SyncHandler(BaseHandler):
notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, all_ephemeral_by_room
)
notif_count = None
highlight_count = None
if notifs is not None:
notif_count = len(notifs)
highlight_count = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
just_joined = yield self.check_joined_room(sync_config, state)
if just_joined:
@ -553,7 +565,8 @@ class SyncHandler(BaseHandler):
account_data=self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
),
unread_notification_count=notif_count
unread_notification_count=notif_count,
unread_highlight_count=highlight_count,
)
logger.debug("Result for room %s: %r", room_id, room_sync)
@ -575,7 +588,8 @@ class SyncHandler(BaseHandler):
for room_id in joined_room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token,
ephemeral_by_room, tags_by_room, account_data_by_room
ephemeral_by_room, tags_by_room, account_data_by_room,
all_ephemeral_by_room=all_ephemeral_by_room,
)
if room_sync:
joined.append(room_sync)
@ -655,7 +669,8 @@ class SyncHandler(BaseHandler):
def incremental_sync_with_gap_for_room(self, room_id, sync_config,
since_token, now_token,
ephemeral_by_room, tags_by_room,
account_data_by_room):
account_data_by_room,
all_ephemeral_by_room):
""" Get the incremental delta needed to bring the client up to date for
the room. Gives the client the most recent events and the changes to
state.
@ -671,7 +686,7 @@ class SyncHandler(BaseHandler):
room_id, sync_config, now_token, since_token,
)
logging.debug("Recents %r", batch)
logger.debug("Recents %r", batch)
current_state = yield self.get_state_at(room_id, now_token)
@ -690,11 +705,16 @@ class SyncHandler(BaseHandler):
state = yield self.get_state_at(room_id, now_token)
notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, ephemeral_by_room
room_id, sync_config, all_ephemeral_by_room
)
notif_count = None
highlight_count = None
if notifs is not None:
notif_count = len(notifs)
highlight_count = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
room_sync = JoinedSyncResult(
room_id=room_id,
@ -705,6 +725,7 @@ class SyncHandler(BaseHandler):
room_id, tags_by_room, account_data_by_room
),
unread_notification_count=notif_count,
unread_highlight_count=highlight_count,
)
logger.debug("Room sync: %r", room_sync)
@ -734,7 +755,7 @@ class SyncHandler(BaseHandler):
leave_event.room_id, sync_config, leave_token, since_token,
)
logging.debug("Recents %r", batch)
logger.debug("Recents %r", batch)
state_events_at_leave = yield self.store.get_state_for_event(
leave_event.event_id
@ -850,8 +871,19 @@ class SyncHandler(BaseHandler):
notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id
)
else:
# There is no new information in this period, so your notification
# count is whatever it was last time.
defer.returnValue(None)
defer.returnValue(notifs)
defer.returnValue(notifs)
# There is no new information in this period, so your notification
# count is whatever it was last time.
defer.returnValue(None)
def _action_has_highlight(actions):
for action in actions:
try:
if action.get("set_tweak", None) == "highlight":
return action.get("value", True)
except AttributeError:
pass
return False

View file

@ -37,7 +37,7 @@ class Pusher(object):
MAX_BACKOFF = 60 * 60 * 1000
GIVE_UP_AFTER = 24 * 60 * 60 * 1000
def __init__(self, _hs, profile_tag, user_name, app_id,
def __init__(self, _hs, profile_tag, user_id, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since):
self.hs = _hs
@ -45,7 +45,7 @@ class Pusher(object):
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.profile_tag = profile_tag
self.user_name = user_name
self.user_id = user_id
self.app_id = app_id
self.app_display_name = app_display_name
self.device_display_name = device_display_name
@ -95,14 +95,14 @@ class Pusher(object):
# we fail to dispatch the push)
config = PaginationConfig(from_token=None, limit='1')
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config, timeout=0, affect_presence=False
self.user_id, config, timeout=0, affect_presence=False
)
self.last_token = chunk['end']
self.store.update_pusher_last_token(
self.app_id, self.pushkey, self.user_name, self.last_token
self.app_id, self.pushkey, self.user_id, self.last_token
)
logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token)
self.pushkey, self.user_id, self.last_token)
wait = 0
while self.alive:
@ -127,7 +127,7 @@ class Pusher(object):
config = PaginationConfig(from_token=from_tok, limit='1')
timeout = (300 + random.randint(-60, 60)) * 1000
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config, timeout=timeout, affect_presence=False
self.user_id, config, timeout=timeout, affect_presence=False
)
# limiting to 1 may get 1 event plus 1 presence event, so
@ -144,7 +144,7 @@ class Pusher(object):
if read_receipt:
for receipt_part in read_receipt['content'].values():
if 'm.read' in receipt_part:
if self.user_name in receipt_part['m.read'].keys():
if self.user_id in receipt_part['m.read'].keys():
have_updated_badge = True
if not single_event:
@ -154,7 +154,7 @@ class Pusher(object):
yield self.store.update_pusher_last_token(
self.app_id,
self.pushkey,
self.user_name,
self.user_id,
self.last_token
)
return
@ -165,8 +165,8 @@ class Pusher(object):
processed = False
rule_evaluator = yield \
push_rule_evaluator.evaluator_for_user_name_and_profile_tag(
self.user_name, self.profile_tag, single_event['room_id'], self.store
push_rule_evaluator.evaluator_for_user_id_and_profile_tag(
self.user_id, self.profile_tag, single_event['room_id'], self.store
)
actions = yield rule_evaluator.actions_for_event(single_event)
@ -192,7 +192,7 @@ class Pusher(object):
pk
)
yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk, self.user_name
self.app_id, pk, self.user_id
)
else:
if have_updated_badge:
@ -208,7 +208,7 @@ class Pusher(object):
yield self.store.update_pusher_last_token_and_success(
self.app_id,
self.pushkey,
self.user_name,
self.user_id,
self.last_token,
self.clock.time_msec()
)
@ -217,7 +217,7 @@ class Pusher(object):
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.user_id,
self.failing_since)
else:
if not self.failing_since:
@ -225,7 +225,7 @@ class Pusher(object):
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.user_id,
self.failing_since
)
@ -237,13 +237,13 @@ class Pusher(object):
# of old notifications.
logger.warn("Giving up on a notification to user %s, "
"pushkey %s",
self.user_name, self.pushkey)
self.user_id, self.pushkey)
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
yield self.store.update_pusher_last_token(
self.app_id,
self.pushkey,
self.user_name,
self.user_id,
self.last_token
)
@ -251,14 +251,14 @@ class Pusher(object):
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.user_id,
self.failing_since
)
else:
logger.warn("Failed to dispatch push for user %s "
"(failing for %dms)."
"Trying again in %dms",
self.user_name,
self.user_id,
self.clock.time_msec() - self.failing_since,
self.backoff_delay)
yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
@ -299,11 +299,11 @@ class Pusher(object):
membership_list = (Membership.INVITE, Membership.JOIN)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=self.user_name,
user_id=self.user_id,
membership_list=membership_list
)
user_is_guest = yield self.store.is_guest(UserID.from_string(self.user_name))
user_is_guest = yield self.store.is_guest(self.user_id)
# XXX: importing inside method to break circular dependency.
# should sort out the mess by moving all this logic out of
@ -311,7 +311,7 @@ class Pusher(object):
# handler to somewhere more amenable to re-use.
from synapse.handlers.sync import SyncConfig
sync_config = SyncConfig(
user=UserID.from_string(self.user_name),
user=UserID.from_string(self.user_id),
filter=FilterCollection({}),
is_guest=user_is_guest,
)
@ -328,13 +328,13 @@ class Pusher(object):
badge += 1
else:
last_unread_event_id = sync_handler.last_read_event_id_for_room_and_user(
r.room_id, self.user_name, ephemeral_by_room
r.room_id, self.user_id, ephemeral_by_room
)
if last_unread_event_id:
notifs = yield (
self.store.get_unread_event_push_actions_by_room_for_user(
r.room_id, self.user_name, last_unread_event_id
r.room_id, self.user_id, last_unread_event_id
)
)
badge += len(notifs)

View file

@ -25,8 +25,9 @@ logger = logging.getLogger(__name__)
class ActionGenerator:
def __init__(self, store):
self.store = store
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
# really we want to get all user ids and all profile tags too,
# since we want the actions for each profile tag for every user and
# also actions for a client with no profile tag for each user.
@ -42,7 +43,7 @@ class ActionGenerator:
)
bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id(
event.room_id, self.store
event.room_id, self.hs, self.store
)
actions_by_user = yield bulk_evaluator.action_for_event_by_user(event, handler)

View file

@ -15,27 +15,25 @@
from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
def list_with_base_rules(rawrules, user_name):
def list_with_base_rules(rawrules):
ruleslist = []
# shove the server default rules for each kind onto the end of each
current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
ruleslist.extend(make_base_prepend_rules(
user_name, PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
for r in rawrules:
if r['priority_class'] < current_prio_class:
while r['priority_class'] < current_prio_class:
ruleslist.extend(make_base_append_rules(
user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
current_prio_class -= 1
if current_prio_class > 0:
ruleslist.extend(make_base_prepend_rules(
user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
@ -43,223 +41,232 @@ def list_with_base_rules(rawrules, user_name):
while current_prio_class > 0:
ruleslist.extend(make_base_append_rules(
user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
current_prio_class -= 1
if current_prio_class > 0:
ruleslist.extend(make_base_prepend_rules(
user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
return ruleslist
def make_base_append_rules(user, kind):
def make_base_append_rules(kind):
rules = []
if kind == 'override':
rules = make_base_append_override_rules()
rules = BASE_APPEND_OVRRIDE_RULES
elif kind == 'underride':
rules = make_base_append_underride_rules(user)
rules = BASE_APPEND_UNDERRIDE_RULES
elif kind == 'content':
rules = make_base_append_content_rules(user)
for r in rules:
r['priority_class'] = PRIORITY_CLASS_MAP[kind]
r['default'] = True # Deprecated, left for backwards compat
rules = BASE_APPEND_CONTENT_RULES
return rules
def make_base_prepend_rules(user, kind):
def make_base_prepend_rules(kind):
rules = []
if kind == 'override':
rules = make_base_prepend_override_rules()
for r in rules:
r['priority_class'] = PRIORITY_CLASS_MAP[kind]
r['default'] = True # Deprecated, left for backwards compat
rules = BASE_PREPEND_OVERRIDE_RULES
return rules
def make_base_append_content_rules(user):
return [
{
'rule_id': 'global/content/.m.rule.contains_user_name',
'conditions': [
{
'kind': 'event_match',
'key': 'content.body',
'pattern': user.localpart, # Matrix ID match
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default',
}, {
'set_tweak': 'highlight'
}
]
},
]
BASE_APPEND_CONTENT_RULES = [
{
'rule_id': 'global/content/.m.rule.contains_user_name',
'conditions': [
{
'kind': 'event_match',
'key': 'content.body',
'pattern_type': 'user_localpart'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default',
}, {
'set_tweak': 'highlight'
}
]
},
]
def make_base_prepend_override_rules():
return [
{
'rule_id': 'global/override/.m.rule.master',
'enabled': False,
'conditions': [],
'actions': [
"dont_notify"
]
}
]
BASE_PREPEND_OVERRIDE_RULES = [
{
'rule_id': 'global/override/.m.rule.master',
'enabled': False,
'conditions': [],
'actions': [
"dont_notify"
]
}
]
def make_base_append_override_rules():
return [
{
'rule_id': 'global/override/.m.rule.suppress_notices',
'conditions': [
{
'kind': 'event_match',
'key': 'content.msgtype',
'pattern': 'm.notice',
}
],
'actions': [
'dont_notify',
]
}
]
BASE_APPEND_OVRRIDE_RULES = [
{
'rule_id': 'global/override/.m.rule.suppress_notices',
'conditions': [
{
'kind': 'event_match',
'key': 'content.msgtype',
'pattern': 'm.notice',
'_id': '_suppress_notices',
}
],
'actions': [
'dont_notify',
]
}
]
def make_base_append_underride_rules(user):
return [
{
'rule_id': 'global/underride/.m.rule.call',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.call.invite',
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'ring'
}, {
'set_tweak': 'highlight',
'value': False
}
]
},
{
'rule_id': 'global/underride/.m.rule.contains_display_name',
'conditions': [
{
'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight'
}
]
},
{
'rule_id': 'global/underride/.m.rule.room_one_to_one',
'conditions': [
{
'kind': 'room_member_count',
'is': '2'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight',
'value': False
}
]
},
{
'rule_id': 'global/underride/.m.rule.invite_for_me',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
},
{
'kind': 'event_match',
'key': 'content.membership',
'pattern': 'invite',
},
{
'kind': 'event_match',
'key': 'state_key',
'pattern': user.to_string(),
},
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight',
'value': False
}
]
},
{
'rule_id': 'global/underride/.m.rule.member_event',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
}
],
'actions': [
'notify', {
'set_tweak': 'highlight',
'value': False
}
]
},
{
'rule_id': 'global/underride/.m.rule.message',
'enabled': False,
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.message',
}
],
'actions': [
'notify', {
'set_tweak': 'highlight',
'value': False
}
]
}
]
BASE_APPEND_UNDERRIDE_RULES = [
{
'rule_id': 'global/underride/.m.rule.call',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.call.invite',
'_id': '_call',
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'ring'
}, {
'set_tweak': 'highlight',
'value': False
}
]
},
{
'rule_id': 'global/underride/.m.rule.contains_display_name',
'conditions': [
{
'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight'
}
]
},
{
'rule_id': 'global/underride/.m.rule.room_one_to_one',
'conditions': [
{
'kind': 'room_member_count',
'is': '2',
'_id': 'member_count',
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight',
'value': False
}
]
},
{
'rule_id': 'global/underride/.m.rule.invite_for_me',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
'_id': '_member',
},
{
'kind': 'event_match',
'key': 'content.membership',
'pattern': 'invite',
'_id': '_invite_member',
},
{
'kind': 'event_match',
'key': 'state_key',
'pattern_type': 'user_id'
},
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight',
'value': False
}
]
},
{
'rule_id': 'global/underride/.m.rule.member_event',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
'_id': '_member',
}
],
'actions': [
'notify', {
'set_tweak': 'highlight',
'value': False
}
]
},
{
'rule_id': 'global/underride/.m.rule.message',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.message',
'_id': '_message',
}
],
'actions': [
'notify', {
'set_tweak': 'highlight',
'value': False
}
]
}
]
for r in BASE_APPEND_CONTENT_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['content']
r['default'] = True
for r in BASE_PREPEND_OVERRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['override']
r['default'] = True
for r in BASE_APPEND_OVRRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['override']
r['default'] = True
for r in BASE_APPEND_UNDERRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['underride']
r['default'] = True

View file

@ -14,16 +14,15 @@
# limitations under the License.
import logging
import simplejson as json
import ujson as json
from twisted.internet import defer
from synapse.types import UserID
import baserules
from push_rule_evaluator import PushRuleEvaluator
from push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.api.constants import EventTypes
from synapse.events.utils import serialize_event
logger = logging.getLogger(__name__)
@ -35,28 +34,30 @@ def decode_rule_json(rule):
@defer.inlineCallbacks
def evaluator_for_room_id(room_id, store):
users = yield store.get_users_in_room(room_id)
rules_by_user = yield store.bulk_get_push_rules(users)
def _get_rules(room_id, user_ids, store):
rules_by_user = yield store.bulk_get_push_rules(user_ids)
rules_by_user = {
uid: baserules.list_with_base_rules(
[decode_rule_json(rule_list) for rule_list in rules_by_user[uid]]
if uid in rules_by_user else [],
UserID.from_string(uid),
)
for uid in users
uid: baserules.list_with_base_rules([
decode_rule_json(rule_list)
for rule_list in rules_by_user.get(uid, [])
])
for uid in user_ids
}
member_events = yield store.get_current_state(
room_id=room_id,
event_type='m.room.member',
)
display_names = {}
for ev in member_events:
if ev.content.get("displayname"):
display_names[ev.state_key] = ev.content.get("displayname")
defer.returnValue(rules_by_user)
@defer.inlineCallbacks
def evaluator_for_room_id(room_id, hs, store):
results = yield store.get_receipts_for_room(room_id, "m.read")
user_ids = [
row["user_id"] for row in results
if hs.is_mine_id(row["user_id"])
]
rules_by_user = yield _get_rules(room_id, user_ids, store)
defer.returnValue(BulkPushRuleEvaluator(
room_id, rules_by_user, display_names, users, store
room_id, rules_by_user, user_ids, store
))
@ -69,10 +70,9 @@ class BulkPushRuleEvaluator:
the same logic to run the actual rules, but could be optimised further
(see https://matrix.org/jira/browse/SYN-562)
"""
def __init__(self, room_id, rules_by_user, display_names, users_in_room, store):
def __init__(self, room_id, rules_by_user, users_in_room, store):
self.room_id = room_id
self.rules_by_user = rules_by_user
self.display_names = display_names
self.users_in_room = users_in_room
self.store = store
@ -80,15 +80,30 @@ class BulkPushRuleEvaluator:
def action_for_event_by_user(self, event, handler):
actions_by_user = {}
for uid, rules in self.rules_by_user.items():
display_name = None
if uid in self.display_names:
display_name = self.display_names[uid]
users_dict = yield self.store.are_guests(self.rules_by_user.keys())
is_guest = yield self.store.is_guest(UserID.from_string(uid))
filtered = yield handler._filter_events_for_client(
uid, [event], is_guest=is_guest
)
filtered_by_user = yield handler._filter_events_for_clients(
users_dict.items(), [event]
)
evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room))
condition_cache = {}
member_state = yield self.store.get_state_for_event(
event.event_id,
)
display_names = {}
for ev in member_state.values():
nm = ev.content.get("displayname", None)
if nm and ev.type == EventTypes.Member:
display_names[ev.state_key] = nm
for uid, rules in self.rules_by_user.items():
display_name = display_names.get(uid, None)
filtered = filtered_by_user[uid]
if len(filtered) == 0:
continue
@ -96,29 +111,32 @@ class BulkPushRuleEvaluator:
if 'enabled' in rule and not rule['enabled']:
continue
# XXX: profile tags
if BulkPushRuleEvaluator.event_matches_rule(
event, rule,
display_name, len(self.users_in_room), None
):
matches = _condition_checker(
evaluator, rule['conditions'], uid, display_name, condition_cache
)
if matches:
actions = [x for x in rule['actions'] if x != 'dont_notify']
if len(actions) > 0:
if actions:
actions_by_user[uid] = actions
break
defer.returnValue(actions_by_user)
@staticmethod
def event_matches_rule(event, rule,
display_name, room_member_count, profile_tag):
matches = True
# passing the clock all the way into here is extremely awkward and push
# rules do not care about any of the relative timestamps, so we just
# pass 0 for the current time.
client_event = serialize_event(event, 0)
def _condition_checker(evaluator, conditions, uid, display_name, cache):
for cond in conditions:
_id = cond.get("_id", None)
if _id:
res = cache.get(_id, None)
if res is False:
return False
elif res is True:
continue
for cond in rule['conditions']:
matches &= PushRuleEvaluator._event_fulfills_condition(
client_event, cond, display_name, room_member_count, profile_tag
)
return matches
res = evaluator.matches(cond, uid, display_name, None)
if _id:
cache[_id] = bool(res)
if not res:
return False
return True

View file

@ -23,13 +23,13 @@ logger = logging.getLogger(__name__)
class HttpPusher(Pusher):
def __init__(self, _hs, profile_tag, user_name, app_id,
def __init__(self, _hs, profile_tag, user_id, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since):
super(HttpPusher, self).__init__(
_hs,
profile_tag,
user_name,
user_id,
app_id,
app_display_name,
device_display_name,
@ -87,7 +87,7 @@ class HttpPusher(Pusher):
}
if event['type'] == 'm.room.member':
d['notification']['membership'] = event['content']['membership']
d['notification']['user_is_target'] = event['state_key'] == self.user_name
d['notification']['user_is_target'] = event['state_key'] == self.user_id
if 'content' in event:
d['notification']['content'] = event['content']
@ -117,7 +117,7 @@ class HttpPusher(Pusher):
@defer.inlineCallbacks
def send_badge(self, badge):
logger.info("Sending updated badge count %d to %r", badge, self.user_name)
logger.info("Sending updated badge count %d to %r", badge, self.user_id)
d = {
'notification': {
'id': '',

View file

@ -15,40 +15,71 @@
from twisted.internet import defer
from synapse.types import UserID
import baserules
import logging
import simplejson as json
import re
from synapse.types import UserID
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
GLOB_REGEX = re.compile(r'\\\[(\\\!|)(.*)\\\]')
IS_GLOB = re.compile(r'[\?\*\[\]]')
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
@defer.inlineCallbacks
def evaluator_for_user_name_and_profile_tag(user_name, profile_tag, room_id, store):
rawrules = yield store.get_push_rules_for_user(user_name)
enabled_map = yield store.get_push_rules_enabled_for_user(user_name)
def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
rawrules = yield store.get_push_rules_for_user(user_id)
enabled_map = yield store.get_push_rules_enabled_for_user(user_id)
our_member_event = yield store.get_current_state(
room_id=room_id,
event_type='m.room.member',
state_key=user_name,
state_key=user_id,
)
defer.returnValue(PushRuleEvaluator(
user_name, profile_tag, rawrules, enabled_map,
user_id, profile_tag, rawrules, enabled_map,
room_id, our_member_event, store
))
def _room_member_count(ev, condition, room_member_count):
if 'is' not in condition:
return False
m = INEQUALITY_EXPR.match(condition['is'])
if not m:
return False
ineq = m.group(1)
rhs = m.group(2)
if not rhs.isdigit():
return False
rhs = int(rhs)
if ineq == '' or ineq == '==':
return room_member_count == rhs
elif ineq == '<':
return room_member_count < rhs
elif ineq == '>':
return room_member_count > rhs
elif ineq == '>=':
return room_member_count >= rhs
elif ineq == '<=':
return room_member_count <= rhs
else:
return False
class PushRuleEvaluator:
DEFAULT_ACTIONS = []
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def __init__(self, user_name, profile_tag, raw_rules, enabled_map, room_id,
def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id,
our_member_event, store):
self.user_name = user_name
self.user_id = user_id
self.profile_tag = profile_tag
self.room_id = room_id
self.our_member_event = our_member_event
@ -61,8 +92,7 @@ class PushRuleEvaluator:
rule['actions'] = json.loads(raw_rule['actions'])
rules.append(rule)
user = UserID.from_string(self.user_name)
self.rules = baserules.list_with_base_rules(rules, user)
self.rules = baserules.list_with_base_rules(rules)
self.enabled_map = enabled_map
@ -83,7 +113,7 @@ class PushRuleEvaluator:
has configured both globally and per-room when we have the ability
to do such things.
"""
if ev['user_id'] == self.user_name:
if ev['user_id'] == self.user_id:
# let's assume you probably know about messages you sent yourself
defer.returnValue([])
@ -98,39 +128,44 @@ class PushRuleEvaluator:
room_members = yield self.store.get_users_in_room(room_id)
room_member_count = len(room_members)
evaluator = PushRuleEvaluatorForEvent(ev, room_member_count)
for r in self.rules:
if r['rule_id'] in self.enabled_map:
r['enabled'] = self.enabled_map[r['rule_id']]
elif 'enabled' not in r:
r['enabled'] = True
if not r['enabled']:
enabled = self.enabled_map.get(r['rule_id'], None)
if enabled is not None and not enabled:
continue
if not r.get("enabled", True):
continue
matches = True
conditions = r['conditions']
actions = r['actions']
for c in conditions:
matches &= self._event_fulfills_condition(
ev, c, display_name=my_display_name,
room_member_count=room_member_count,
profile_tag=self.profile_tag
)
logger.debug(
"Rule %s %s",
r['rule_id'], "matches" if matches else "doesn't match"
)
# ignore rules with no actions (we have an explict 'dont_notify')
if len(actions) == 0:
logger.warn(
"Ignoring rule id %s with no actions for user %s",
r['rule_id'], self.user_name
r['rule_id'], self.user_id
)
continue
matches = True
for c in conditions:
matches = evaluator.matches(
c, self.user_id, my_display_name, self.profile_tag
)
if not matches:
break
logger.debug(
"Rule %s %s",
r['rule_id'], "matches" if matches else "doesn't match"
)
if matches:
logger.info(
logger.debug(
"%s matches for user %s, event %s",
r['rule_id'], self.user_name, ev['event_id']
r['rule_id'], self.user_id, ev['event_id']
)
# filter out dont_notify as we treat an empty actions list
@ -139,94 +174,149 @@ class PushRuleEvaluator:
defer.returnValue(actions)
logger.info(
logger.debug(
"No rules match for user %s, event %s",
self.user_name, ev['event_id']
self.user_id, ev['event_id']
)
defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS)
@staticmethod
def _glob_to_regexp(glob):
r = re.escape(glob)
r = re.sub(r'\\\*', r'.*?', r)
r = re.sub(r'\\\?', r'.', r)
# handle [abc], [a-z] and [!a-z] style ranges.
r = re.sub(r'\\\[(\\\!|)(.*)\\\]',
lambda x: ('[%s%s]' % (x.group(1) and '^' or '',
re.sub(r'\\\-', '-', x.group(2)))), r)
return r
class PushRuleEvaluatorForEvent(object):
def __init__(self, event, room_member_count):
self._event = event
self._room_member_count = room_member_count
@staticmethod
def _event_fulfills_condition(ev, condition,
display_name, room_member_count, profile_tag):
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)
def matches(self, condition, user_id, display_name, profile_tag):
if condition['kind'] == 'event_match':
if 'pattern' not in condition:
logger.warn("event_match condition with no pattern")
return False
# XXX: optimisation: cache our pattern regexps
if condition['key'] == 'content.body':
r = r'\b%s\b' % PushRuleEvaluator._glob_to_regexp(condition['pattern'])
else:
r = r'^%s$' % PushRuleEvaluator._glob_to_regexp(condition['pattern'])
val = _value_for_dotted_key(condition['key'], ev)
if val is None:
return False
return re.search(r, val, flags=re.IGNORECASE) is not None
return self._event_match(condition, user_id)
elif condition['kind'] == 'device':
if 'profile_tag' not in condition:
return True
return condition['profile_tag'] == profile_tag
elif condition['kind'] == 'contains_display_name':
# This is special because display names can be different
# between rooms and so you can't really hard code it in a rule.
# Optimisation: we should cache these names and update them from
# the event stream.
if 'content' not in ev or 'body' not in ev['content']:
return False
if not display_name:
return False
return re.search(
r"\b%s\b" % re.escape(display_name), ev['content']['body'],
flags=re.IGNORECASE
) is not None
return self._contains_display_name(display_name)
elif condition['kind'] == 'room_member_count':
if 'is' not in condition:
return False
m = PushRuleEvaluator.INEQUALITY_EXPR.match(condition['is'])
if not m:
return False
ineq = m.group(1)
rhs = m.group(2)
if not rhs.isdigit():
return False
rhs = int(rhs)
if ineq == '' or ineq == '==':
return room_member_count == rhs
elif ineq == '<':
return room_member_count < rhs
elif ineq == '>':
return room_member_count > rhs
elif ineq == '>=':
return room_member_count >= rhs
elif ineq == '<=':
return room_member_count <= rhs
else:
return False
return _room_member_count(
self._event, condition, self._room_member_count
)
else:
return True
def _event_match(self, condition, user_id):
pattern = condition.get('pattern', None)
def _value_for_dotted_key(dotted_key, event):
parts = dotted_key.split(".")
val = event
while len(parts) > 0:
if parts[0] not in val:
return None
val = val[parts[0]]
parts = parts[1:]
return val
if not pattern:
pattern_type = condition.get('pattern_type', None)
if pattern_type == "user_id":
pattern = user_id
elif pattern_type == "user_localpart":
pattern = UserID.from_string(user_id).localpart
if not pattern:
logger.warn("event_match condition with no pattern")
return False
# XXX: optimisation: cache our pattern regexps
if condition['key'] == 'content.body':
body = self._event["content"].get("body", None)
if not body:
return False
return _glob_matches(pattern, body, word_boundary=True)
else:
haystack = self._get_value(condition['key'])
if haystack is None:
return False
return _glob_matches(pattern, haystack)
def _contains_display_name(self, display_name):
if not display_name:
return False
body = self._event["content"].get("body", None)
if not body:
return False
return _glob_matches(display_name, body, word_boundary=True)
def _get_value(self, dotted_key):
return self._value_cache.get(dotted_key, None)
def _glob_matches(glob, value, word_boundary=False):
"""Tests if value matches glob.
Args:
glob (string)
value (string): String to test against glob.
word_boundary (bool): Whether to match against word boundaries or entire
string. Defaults to False.
Returns:
bool
"""
try:
if IS_GLOB.search(glob):
r = re.escape(glob)
r = r.replace(r'\*', '.*?')
r = r.replace(r'\?', '.')
# handle [abc], [a-z] and [!a-z] style ranges.
r = GLOB_REGEX.sub(
lambda x: (
'[%s%s]' % (
x.group(1) and '^' or '',
x.group(2).replace(r'\\\-', '-')
)
),
r,
)
if word_boundary:
r = r"\b%s\b" % (r,)
r = _compile_regex(r)
return r.search(value)
else:
r = r + "$"
r = _compile_regex(r)
return r.match(value)
elif word_boundary:
r = re.escape(glob)
r = r"\b%s\b" % (r,)
r = _compile_regex(r)
return r.search(value)
else:
return value.lower() == glob.lower()
except re.error:
logger.warn("Failed to parse glob to regex: %r", glob)
return False
def _flatten_dict(d, prefix=[], result={}):
for key, value in d.items():
if isinstance(value, basestring):
result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"):
_flatten_dict(value, prefix=(prefix+[key]), result=result)
return result
regex_cache = LruCache(5000)
def _compile_regex(regex_str):
r = regex_cache.get(regex_str, None)
if r:
return r
r = re.compile(regex_str, flags=re.IGNORECASE)
regex_cache[regex_str] = r
return r

View file

@ -37,14 +37,14 @@ class PusherPool:
self._start_pushers(pushers)
@defer.inlineCallbacks
def add_pusher(self, user_name, access_token, profile_tag, kind, app_id,
def add_pusher(self, user_id, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, pushkey, lang, data):
# we try to create the pusher just to validate the config: it
# will then get pulled out of the database,
# recreated, added and started: this means we have only one
# code path adding pushers.
self._create_pusher({
"user_name": user_name,
"user_name": user_id,
"kind": kind,
"profile_tag": profile_tag,
"app_id": app_id,
@ -59,7 +59,7 @@ class PusherPool:
"failing_since": None
})
yield self._add_pusher_to_store(
user_name, access_token, profile_tag, kind, app_id,
user_id, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name,
pushkey, lang, data
)
@ -94,11 +94,11 @@ class PusherPool:
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def _add_pusher_to_store(self, user_name, access_token, profile_tag, kind,
def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, lang, data):
yield self.store.add_pusher(
user_name=user_name,
user_id=user_id,
access_token=access_token,
profile_tag=profile_tag,
kind=kind,
@ -110,14 +110,14 @@ class PusherPool:
lang=lang,
data=data,
)
self._refresh_pusher(app_id, pushkey, user_name)
self._refresh_pusher(app_id, pushkey, user_id)
def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http':
return HttpPusher(
self.hs,
profile_tag=pusherdict['profile_tag'],
user_name=pusherdict['user_name'],
user_id=pusherdict['user_name'],
app_id=pusherdict['app_id'],
app_display_name=pusherdict['app_display_name'],
device_display_name=pusherdict['device_display_name'],
@ -135,14 +135,14 @@ class PusherPool:
)
@defer.inlineCallbacks
def _refresh_pusher(self, app_id, pushkey, user_name):
def _refresh_pusher(self, app_id, pushkey, user_id):
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(
app_id, pushkey
)
p = None
for r in resultlist:
if r['user_name'] == user_name:
if r['user_name'] == user_id:
p = r
if p:
@ -171,12 +171,12 @@ class PusherPool:
logger.info("Started pushers")
@defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey, user_name):
fullid = "%s:%s:%s" % (app_id, pushkey, user_name)
def remove_pusher(self, app_id, pushkey, user_id):
fullid = "%s:%s:%s" % (app_id, pushkey, user_id)
if fullid in self.pushers:
logger.info("Stopping pusher %s", fullid)
self.pushers[fullid].stop()
del self.pushers[fullid]
yield self.store.delete_pusher_by_app_id_pushkey_user_name(
app_id, pushkey, user_name
yield self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id
)

View file

@ -27,6 +27,7 @@ from synapse.push.rulekinds import (
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
)
import copy
import simplejson as json
@ -51,7 +52,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
content = _parse_json(request)
if 'attr' in spec:
self.set_rule_attr(requester.user, spec, content)
self.set_rule_attr(requester.user.to_string(), spec, content)
defer.returnValue((200, {}))
try:
@ -73,7 +74,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
try:
yield self.hs.get_datastore().add_push_rule(
user_name=requester.user.to_string(),
user_id=requester.user.to_string(),
rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class,
conditions=conditions,
@ -126,7 +127,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
ruleslist = baserules.list_with_base_rules(ruleslist, user)
# We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist))
rules = {'global': {}, 'device': {}}
@ -140,6 +142,16 @@ class PushRuleRestServlet(ClientV1RestServlet):
template_name = _priority_class_to_template_name(r['priority_class'])
# Remove internal stuff.
for c in r["conditions"]:
c.pop("_id", None)
pattern_type = c.pop("pattern_type", None)
if pattern_type == "user_id":
c["pattern"] = user.to_string()
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart
if r['priority_class'] > PRIORITY_CLASS_MAP['override']:
# per-device rule
profile_tag = _profile_tag_from_conditions(r["conditions"])
@ -206,7 +218,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_OPTIONS(self, _):
return 200, {}
def set_rule_attr(self, user_name, spec, val):
def set_rule_attr(self, user_id, spec, val):
if spec['attr'] == 'enabled':
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
@ -217,15 +229,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
self.hs.get_datastore().set_push_rule_enabled(
user_name, namespaced_rule_id, val
user_id, namespaced_rule_id, val
)
else:
raise UnrecognizedRequestError()
def get_rule_attr(self, user_name, namespaced_rule_id, attr):
def get_rule_attr(self, user_id, namespaced_rule_id, attr):
if attr == 'enabled':
return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id(
user_name, namespaced_rule_id
user_id, namespaced_rule_id
)
else:
raise UnrecognizedRequestError()

View file

@ -41,7 +41,7 @@ class PusherRestServlet(ClientV1RestServlet):
and 'kind' in content and
content['kind'] is None):
yield pusher_pool.remove_pusher(
content['app_id'], content['pushkey'], user_name=user.to_string()
content['app_id'], content['pushkey'], user_id=user.to_string()
)
defer.returnValue((200, {}))
@ -71,7 +71,7 @@ class PusherRestServlet(ClientV1RestServlet):
try:
yield pusher_pool.add_pusher(
user_name=user.to_string(),
user_id=user.to_string(),
access_token=requester.access_token_id,
profile_tag=content['profile_tag'],
kind=content['kind'],

View file

@ -414,10 +414,16 @@ class RoomEventContext(ClientV1RestServlet):
requester.is_guest,
)
if not results:
raise SynapseError(
404, "Event not found.", errcode=Codes.NOT_FOUND
)
time_now = self.clock.time_msec()
results["events_before"] = [
serialize_event(event, time_now) for event in results["events_before"]
]
results["event"] = serialize_event(results["event"], time_now)
results["events_after"] = [
serialize_event(event, time_now) for event in results["events_after"]
]
@ -436,7 +442,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
"(?P<membership_action>join|invite|leave|ban|kick|forget)")
"(?P<membership_action>join|invite|leave|ban|unban|kick|forget)")
register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks
@ -445,9 +451,6 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
request,
allow_guest=True,
)
user = requester.user
effective_membership_action = membership_action
if requester.is_guest and membership_action not in {
Membership.JOIN,
@ -457,13 +460,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
content = _parse_json(request)
# target user is you unless it is an invite
state_key = user.to_string()
if membership_action == "invite" and self._has_3pid_invite_keys(content):
yield self.handlers.room_member_handler.do_3pid_invite(
room_id,
user,
requester.user,
content["medium"],
content["address"],
content["id_server"],
@ -472,42 +472,21 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
)
defer.returnValue((200, {}))
return
elif membership_action in ["invite", "ban", "kick"]:
if "user_id" in content:
state_key = content["user_id"]
else:
target = requester.user
if membership_action in ["invite", "ban", "unban", "kick"]:
if "user_id" not in content:
raise SynapseError(400, "Missing user_id key.")
target = UserID.from_string(content["user_id"])
# make sure it looks like a user ID; it'll throw if it's invalid.
UserID.from_string(state_key)
if membership_action == "kick":
effective_membership_action = "leave"
elif membership_action == "forget":
effective_membership_action = "leave"
msg_handler = self.handlers.message_handler
content = {"membership": unicode(effective_membership_action)}
if requester.is_guest:
content["kind"] = "guest"
yield msg_handler.create_and_send_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": user.to_string(),
"state_key": state_key,
},
token_id=requester.access_token_id,
yield self.handlers.room_member_handler.update_membership(
requester=requester,
target=target,
room_id=room_id,
action=membership_action,
txn_id=txn_id,
is_guest=requester.is_guest,
)
if membership_action == "forget":
yield self.handlers.room_member_handler.forget(user, room_id)
defer.returnValue((200, {}))
def _has_3pid_invite_keys(self, content):

View file

@ -313,6 +313,7 @@ class SyncRestServlet(RestServlet):
ephemeral_events = filter.filter_room_ephemeral(room.ephemeral)
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notification_count"] = room.unread_notification_count
result["unread_highlight_count"] = room.unread_highlight_count
return result

View file

@ -139,6 +139,9 @@ class BaseHomeServer(object):
def is_mine(self, domain_specific_string):
return domain_specific_string.domain == self.hostname
def is_mine_id(self, string):
return string.split(":", 1)[1] == self.hostname
# Build magic accessors for every dependency
for depname in BaseHomeServer.DEPENDENCIES:
BaseHomeServer._make_dependency_method(depname)

View file

@ -15,12 +15,12 @@
import logging
import urllib
import yaml
from simplejson import JSONDecodeError
import simplejson as json
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.appservice import ApplicationService, AppServiceTransaction
from synapse.config._base import ConfigError
from synapse.storage.roommember import RoomsForUser
from synapse.types import UserID
from ._base import SQLBaseStore
@ -144,66 +144,9 @@ class ApplicationServiceStore(SQLBaseStore):
return rooms_for_user_matching_user_id
def _parse_services_dict(self, results):
# SQL results in the form:
# [
# {
# 'regex': "something",
# 'url': "something",
# 'namespace': enum,
# 'as_id': 0,
# 'token': "something",
# 'hs_token': "otherthing",
# 'id': 0
# }
# ]
services = {}
for res in results:
as_token = res["token"]
if as_token is None:
continue
if as_token not in services:
# add the service
services[as_token] = {
"id": res["id"],
"url": res["url"],
"token": as_token,
"hs_token": res["hs_token"],
"sender": res["sender"],
"namespaces": {
ApplicationService.NS_USERS: [],
ApplicationService.NS_ALIASES: [],
ApplicationService.NS_ROOMS: []
}
}
# add the namespace regex if one exists
ns_int = res["namespace"]
if ns_int is None:
continue
try:
services[as_token]["namespaces"][
ApplicationService.NS_LIST[ns_int]].append(
json.loads(res["regex"])
)
except IndexError:
logger.error("Bad namespace enum '%s'. %s", ns_int, res)
except JSONDecodeError:
logger.error("Bad regex object '%s'", res["regex"])
service_list = []
for service in services.values():
service_list.append(ApplicationService(
token=service["token"],
url=service["url"],
namespaces=service["namespaces"],
hs_token=service["hs_token"],
sender=service["sender"],
id=service["id"]
))
return service_list
def _load_appservice(self, as_info):
required_string_fields = [
# TODO: Add id here when it's stable to release
"url", "as_token", "hs_token", "sender_localpart"
]
for field in required_string_fields:
@ -245,7 +188,7 @@ class ApplicationServiceStore(SQLBaseStore):
namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["as_token"] # the token is the only unique thing here
id=as_info["id"] if "id" in as_info else as_info["as_token"],
)
def _populate_appservice_cache(self, config_files):
@ -256,15 +199,38 @@ class ApplicationServiceStore(SQLBaseStore):
)
return
# Dicts of value -> filename
seen_as_tokens = {}
seen_ids = {}
for config_file in config_files:
try:
with open(config_file, 'r') as f:
appservice = self._load_appservice(yaml.load(f))
if appservice.id in seen_ids:
raise ConfigError(
"Cannot reuse ID across application services: "
"%s (files: %s, %s)" % (
appservice.id, config_file, seen_ids[appservice.id],
)
)
seen_ids[appservice.id] = config_file
if appservice.token in seen_as_tokens:
raise ConfigError(
"Cannot reuse as_token across application services: "
"%s (files: %s, %s)" % (
appservice.token,
config_file,
seen_as_tokens[appservice.token],
)
)
seen_as_tokens[appservice.token] = config_file
logger.info("Loaded application service: %s", appservice)
self.services_cache.append(appservice)
except Exception as e:
logger.error("Failed to load appservice from '%s'", config_file)
logger.exception(e)
raise
class ApplicationServiceTransactionStore(SQLBaseStore):

View file

@ -17,7 +17,7 @@ from ._base import SQLBaseStore
from twisted.internet import defer
import logging
import simplejson as json
import ujson as json
logger = logging.getLogger(__name__)
@ -84,7 +84,8 @@ class EventPushActionsStore(SQLBaseStore):
)
)
return [
{"event_id": row[0], "actions": row[1]} for row in txn.fetchall()
{"event_id": row[0], "actions": json.loads(row[1])}
for row in txn.fetchall()
]
ret = yield self.runInteraction(

View file

@ -25,11 +25,11 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore):
@cachedInlineCallbacks()
def get_push_rules_for_user(self, user_name):
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
table="push_rules",
keyvalues={
"user_name": user_name,
"user_name": user_id,
},
retcols=(
"user_name", "rule_id", "priority_class", "priority",
@ -45,11 +45,11 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows)
@cachedInlineCallbacks()
def get_push_rules_enabled_for_user(self, user_name):
def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list(
table="push_rules_enable",
keyvalues={
'user_name': user_name
'user_name': user_id
},
retcols=(
"user_name", "rule_id", "enabled",
@ -122,7 +122,7 @@ class PushRuleStore(SQLBaseStore):
)
defer.returnValue(ret)
def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
def _add_push_rule_relative_txn(self, txn, user_id, **kwargs):
after = kwargs.pop("after", None)
relative_to_rule = kwargs.pop("before", after)
@ -130,7 +130,7 @@ class PushRuleStore(SQLBaseStore):
txn,
table="push_rules",
keyvalues={
"user_name": user_name,
"user_name": user_id,
"rule_id": relative_to_rule,
},
retcols=["priority_class", "priority"],
@ -154,7 +154,7 @@ class PushRuleStore(SQLBaseStore):
new_rule.pop("before", None)
new_rule.pop("after", None)
new_rule['priority_class'] = priority_class
new_rule['user_name'] = user_name
new_rule['user_name'] = user_id
new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
# check if the priority before/after is free
@ -170,7 +170,7 @@ class PushRuleStore(SQLBaseStore):
"SELECT COUNT(*) FROM push_rules"
" WHERE user_name = ? AND priority_class = ? AND priority = ?"
)
txn.execute(sql, (user_name, priority_class, new_rule_priority))
txn.execute(sql, (user_id, priority_class, new_rule_priority))
res = txn.fetchall()
num_conflicting = res[0][0]
@ -187,14 +187,14 @@ class PushRuleStore(SQLBaseStore):
else:
sql += ">= ?"
txn.execute(sql, (user_name, priority_class, new_rule_priority))
txn.execute(sql, (user_id, priority_class, new_rule_priority))
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_name,)
self.get_push_rules_for_user.invalidate, (user_id,)
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
self._simple_insert_txn(
@ -203,14 +203,14 @@ class PushRuleStore(SQLBaseStore):
values=new_rule,
)
def _add_push_rule_highest_priority_txn(self, txn, user_name,
def _add_push_rule_highest_priority_txn(self, txn, user_id,
priority_class, **kwargs):
# find the highest priority rule in that class
sql = (
"SELECT COUNT(*), MAX(priority) FROM push_rules"
" WHERE user_name = ? and priority_class = ?"
)
txn.execute(sql, (user_name, priority_class))
txn.execute(sql, (user_id, priority_class))
res = txn.fetchall()
(how_many, highest_prio) = res[0]
@ -221,15 +221,15 @@ class PushRuleStore(SQLBaseStore):
# and insert the new rule
new_rule = kwargs
new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
new_rule['user_name'] = user_name
new_rule['user_name'] = user_id
new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_name,)
self.get_push_rules_for_user.invalidate, (user_id,)
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
self._simple_insert_txn(
@ -239,48 +239,48 @@ class PushRuleStore(SQLBaseStore):
)
@defer.inlineCallbacks
def delete_push_rule(self, user_name, rule_id):
def delete_push_rule(self, user_id, rule_id):
"""
Delete a push rule. Args specify the row to be deleted and can be
any of the columns in the push_rule table, but below are the
standard ones
Args:
user_name (str): The matrix ID of the push rule owner
user_id (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted
"""
yield self._simple_delete_one(
"push_rules",
{'user_name': user_name, 'rule_id': rule_id},
{'user_name': user_id, 'rule_id': rule_id},
desc="delete_push_rule",
)
self.get_push_rules_for_user.invalidate((user_name,))
self.get_push_rules_enabled_for_user.invalidate((user_name,))
self.get_push_rules_for_user.invalidate((user_id,))
self.get_push_rules_enabled_for_user.invalidate((user_id,))
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled):
def set_push_rule_enabled(self, user_id, rule_id, enabled):
ret = yield self.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
user_name, rule_id, enabled
user_id, rule_id, enabled
)
defer.returnValue(ret)
def _set_push_rule_enabled_txn(self, txn, user_name, rule_id, enabled):
def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
self._simple_upsert_txn(
txn,
"push_rules_enable",
{'user_name': user_name, 'rule_id': rule_id},
{'user_name': user_id, 'rule_id': rule_id},
{'enabled': 1 if enabled else 0},
{'id': new_id},
)
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_name,)
self.get_push_rules_for_user.invalidate, (user_id,)
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)

View file

@ -80,7 +80,7 @@ class PusherStore(SQLBaseStore):
defer.returnValue(rows)
@defer.inlineCallbacks
def add_pusher(self, user_name, access_token, profile_tag, kind, app_id,
def add_pusher(self, user_id, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name,
pushkey, pushkey_ts, lang, data):
try:
@ -90,7 +90,7 @@ class PusherStore(SQLBaseStore):
dict(
app_id=app_id,
pushkey=pushkey,
user_name=user_name,
user_name=user_id,
),
dict(
access_token=access_token,
@ -112,38 +112,38 @@ class PusherStore(SQLBaseStore):
raise StoreError(500, "Problem creating pusher.")
@defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey_user_name(self, app_id, pushkey, user_name):
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
yield self._simple_delete_one(
"pushers",
{"app_id": app_id, "pushkey": pushkey, 'user_name': user_name},
desc="delete_pusher_by_app_id_pushkey_user_name",
{"app_id": app_id, "pushkey": pushkey, 'user_name': user_id},
desc="delete_pusher_by_app_id_pushkey_user_id",
)
@defer.inlineCallbacks
def update_pusher_last_token(self, app_id, pushkey, user_name, last_token):
def update_pusher_last_token(self, app_id, pushkey, user_id, last_token):
yield self._simple_update_one(
"pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
{'last_token': last_token},
desc="update_pusher_last_token",
)
@defer.inlineCallbacks
def update_pusher_last_token_and_success(self, app_id, pushkey, user_name,
def update_pusher_last_token_and_success(self, app_id, pushkey, user_id,
last_token, last_success):
yield self._simple_update_one(
"pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
{'last_token': last_token, 'last_success': last_success},
desc="update_pusher_last_token_and_success",
)
@defer.inlineCallbacks
def update_pusher_failing_since(self, app_id, pushkey, user_name,
def update_pusher_failing_since(self, app_id, pushkey, user_id,
failing_since):
yield self._simple_update_one(
"pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
{'failing_since': failing_since},
desc="update_pusher_failing_since",
)

View file

@ -14,7 +14,7 @@
# limitations under the License.
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
from synapse.util.caches import cache_counter, caches_by_name
from twisted.internet import defer
@ -33,6 +33,18 @@ class ReceiptsStore(SQLBaseStore):
self._receipts_stream_cache = _RoomStreamChangeCache()
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
return self._simple_select_list(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
},
retcols=("user_id", "event_id"),
desc="get_receipts_for_room",
)
@defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
"""Get receipts for multiple rooms for sending to clients.

View file

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
class RegistrationStore(SQLBaseStore):
@ -256,10 +256,10 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False)
@cachedInlineCallbacks()
def is_guest(self, user):
def is_guest(self, user_id):
res = yield self._simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
keyvalues={"name": user_id},
retcol="is_guest",
allow_none=True,
desc="is_guest",
@ -267,6 +267,26 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False)
@cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
inlineCallbacks=True)
def are_guests(self, user_ids):
sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
",".join("?" for _ in user_ids),
)
rows = yield self._execute(
"are_guests", self.cursor_to_dict, sql, *user_ids
)
result = {user_id: False for user_id in user_ids}
result.update({
row["name"]: bool(row["is_guest"])
for row in rows
})
defer.returnValue(result)
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id"

View file

@ -287,6 +287,7 @@ class RoomMemberStore(SQLBaseStore):
txn.execute(sql, (user_id, room_id))
yield self.runInteraction("forget_membership", f)
self.was_forgotten_at.invalidate_all()
self.who_forgot_in_room.invalidate_all()
self.did_forget.invalidate((user_id, room_id))
@cachedInlineCallbacks(num_args=2)
@ -336,3 +337,15 @@ class RoomMemberStore(SQLBaseStore):
return rows[0][0]
forgot = yield self.runInteraction("did_forget_membership_at", f)
defer.returnValue(forgot == 1)
@cached()
def who_forgot_in_room(self, room_id):
return self._simple_select_list(
table="room_memberships",
retcols=("user_id", "event_id"),
keyvalues={
"room_id": room_id,
"forgotten": 1,
},
desc="who_forgot"
)

View file

@ -22,6 +22,9 @@ import logging
logger = logging.getLogger(__name__)
MAX_LIMIT = 1000
class SourcePaginationConfig(object):
"""A configuration object which stores pagination parameters for a
@ -32,7 +35,7 @@ class SourcePaginationConfig(object):
self.from_key = from_key
self.to_key = to_key
self.direction = 'f' if direction == 'f' else 'b'
self.limit = int(limit) if limit is not None else None
self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
def __repr__(self):
return (
@ -49,7 +52,7 @@ class PaginationConfig(object):
self.from_token = from_token
self.to_token = to_token
self.direction = 'f' if direction == 'f' else 'b'
self.limit = int(limit) if limit is not None else None
self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
@classmethod
def from_request(cls, request, raise_invalid_params=True,

View file

@ -29,6 +29,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
def setUp(self):
self.service = ApplicationService(
id="unique_identifier",
url="some_url",
token="some_token",
namespaces={

View file

@ -1,141 +0,0 @@
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from tests import unittest
from synapse.api.constants import EventTypes
from synapse.events import FrozenEvent
from synapse.handlers.federation import FederationHandler
from mock import NonCallableMock, ANY, Mock
from ..utils import setup_test_homeserver
class FederationTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.state_handler = NonCallableMock(spec_set=[
"compute_event_context",
])
self.auth = NonCallableMock(spec_set=[
"check",
"check_host_in_room",
])
self.hostname = "test"
hs = yield setup_test_homeserver(
self.hostname,
datastore=NonCallableMock(spec_set=[
"persist_event",
"store_room",
"get_room",
"get_destination_retry_timings",
"set_destination_retry_timings",
"have_events",
"get_users_in_room",
"bulk_get_push_rules",
"get_current_state",
"set_push_actions_for_event_and_users",
"is_guest",
"get_state_for_events",
]),
resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]),
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
handlers=NonCallableMock(spec_set=[
"room_member_handler",
"federation_handler",
]),
auth=self.auth,
state_handler=self.state_handler,
keyring=Mock(),
)
self.datastore = hs.get_datastore()
self.handlers = hs.get_handlers()
self.notifier = hs.get_notifier()
self.hs = hs
self.handlers.federation_handler = FederationHandler(self.hs)
self.datastore.get_state_for_events.return_value = {"$a:b": {}}
@defer.inlineCallbacks
def test_msg(self):
pdu = FrozenEvent({
"type": EventTypes.Message,
"room_id": "foo",
"content": {"msgtype": u"fooo"},
"origin_server_ts": 0,
"event_id": "$a:b",
"user_id":"@a:b",
"origin": "b",
"auth_events": [],
"hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
})
self.datastore.persist_event.return_value = defer.succeed((1,1))
self.datastore.get_room.return_value = defer.succeed(True)
self.datastore.get_users_in_room.return_value = ["@a:b"]
self.datastore.bulk_get_push_rules.return_value = {}
self.datastore.get_current_state.return_value = {}
self.auth.check_host_in_room.return_value = defer.succeed(True)
retry_timings_res = {
"destination": "",
"retry_last_ts": 0,
"retry_interval": 0,
}
self.datastore.get_destination_retry_timings.return_value = (
defer.succeed(retry_timings_res)
)
def have_events(event_ids):
return defer.succeed({})
self.datastore.have_events.side_effect = have_events
def annotate(ev, old_state=None, outlier=False):
context = Mock()
context.current_state = {}
context.auth_events = {}
return defer.succeed(context)
self.state_handler.compute_event_context.side_effect = annotate
yield self.handlers.federation_handler.on_receive_pdu(
"fo", pdu, False
)
self.datastore.persist_event.assert_called_once_with(
ANY,
is_new_state=True,
backfilled=False,
current_state=None,
context=ANY,
)
self.state_handler.compute_event_context.assert_called_once_with(
ANY, old_state=None, outlier=False
)
self.auth.check.assert_called_once_with(ANY, auth_events={})
self.notifier.on_new_room_event.assert_called_once_with(
ANY, 1, 1, extra_users=[]
)

View file

@ -1,418 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .. import unittest
from synapse.api.constants import EventTypes, Membership
from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler
from synapse.handlers.profile import ProfileHandler
from synapse.types import UserID
from ..utils import setup_test_homeserver
from mock import Mock, NonCallableMock
class RoomMemberHandlerTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hostname = "red"
hs = yield setup_test_homeserver(
self.hostname,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
datastore=NonCallableMock(spec_set=[
"persist_event",
"get_room_member",
"get_room",
"store_room",
"get_latest_events_in_room",
"add_event_hashes",
"get_users_in_room",
"bulk_get_push_rules",
"get_current_state",
"set_push_actions_for_event_and_users",
"get_state_for_events",
"is_guest",
]),
resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]),
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
handlers=NonCallableMock(spec_set=[
"room_member_handler",
"profile_handler",
"federation_handler",
]),
auth=NonCallableMock(spec_set=[
"check",
"add_auth_events",
"check_host_in_room",
]),
state_handler=NonCallableMock(spec_set=[
"compute_event_context",
"get_current_state",
]),
)
self.federation = NonCallableMock(spec_set=[
"handle_new_event",
"send_invite",
"get_state_for_room",
])
self.datastore = hs.get_datastore()
self.handlers = hs.get_handlers()
self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.auth = hs.get_auth()
self.hs = hs
self.handlers.federation_handler = self.federation
self.distributor.declare("collect_presencelike_data")
self.handlers.room_member_handler = RoomMemberHandler(self.hs)
self.handlers.profile_handler = ProfileHandler(self.hs)
self.room_member_handler = self.handlers.room_member_handler
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
self.datastore.persist_event.return_value = (1,1)
self.datastore.add_event_hashes.return_value = []
self.datastore.get_users_in_room.return_value = ["@bob:red"]
self.datastore.bulk_get_push_rules.return_value = {}
@defer.inlineCallbacks
def test_invite(self):
room_id = "!foo:red"
user_id = "@bob:red"
target_user_id = "@red:blue"
content = {"membership": Membership.INVITE}
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": target_user_id,
"room_id": room_id,
"content": content,
})
self.datastore.get_latest_events_in_room.return_value = (
defer.succeed([])
)
self.datastore.get_current_state.return_value = {}
self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
def annotate(_):
ctx = Mock()
ctx.current_state = {
(EventTypes.Member, "@alice:green"): self._create_member(
user_id="@alice:green",
room_id=room_id,
),
(EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
),
}
ctx.prev_state_events = []
return defer.succeed(ctx)
self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[
(EventTypes.Member, "@bob:red")
]
return defer.succeed(True)
self.auth.add_auth_events.side_effect = add_auth
def send_invite(domain, event):
return defer.succeed(event)
self.federation.send_invite.side_effect = send_invite
room_handler = self.room_member_handler
event, context = yield room_handler._create_new_client_event(
builder
)
yield room_handler.change_membership(event, context)
self.state_handler.compute_event_context.assert_called_once_with(
builder
)
self.auth.add_auth_events.assert_called_once_with(
builder, context
)
self.federation.send_invite.assert_called_once_with(
"blue", event,
)
self.datastore.persist_event.assert_called_once_with(
event, context=context,
)
self.notifier.on_new_room_event.assert_called_once_with(
event, 1, 1, extra_users=[UserID.from_string(target_user_id)]
)
self.assertFalse(self.datastore.get_room.called)
self.assertFalse(self.datastore.store_room.called)
self.assertFalse(self.federation.get_state_for_room.called)
@defer.inlineCallbacks
def test_simple_join(self):
room_id = "!foo:red"
user_id = "@bob:red"
user = UserID.from_string(user_id)
join_signal_observer = Mock()
self.distributor.observe("user_joined_room", join_signal_observer)
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": user_id,
"room_id": room_id,
"content": {"membership": Membership.JOIN},
})
self.datastore.get_latest_events_in_room.return_value = (
defer.succeed([])
)
self.datastore.get_current_state.return_value = {}
self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
def annotate(_):
ctx = Mock()
ctx.current_state = {
(EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
membership=Membership.INVITE
),
}
ctx.prev_state_events = []
return defer.succeed(ctx)
self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[
(EventTypes.Member, "@bob:red")
]
return defer.succeed(True)
self.auth.add_auth_events.side_effect = add_auth
room_handler = self.room_member_handler
event, context = yield room_handler._create_new_client_event(
builder
)
# Actual invocation
yield room_handler.change_membership(event, context)
self.federation.handle_new_event.assert_called_once_with(
event, destinations=set()
)
self.datastore.persist_event.assert_called_once_with(
event, context=context
)
self.notifier.on_new_room_event.assert_called_once_with(
event, 1, 1, extra_users=[user]
)
join_signal_observer.assert_called_with(
user=user, room_id=room_id
)
def _create_member(self, user_id, room_id, membership=Membership.JOIN):
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": user_id,
"room_id": room_id,
"content": {"membership": membership},
})
return builder.build()
@defer.inlineCallbacks
def test_simple_leave(self):
room_id = "!foo:red"
user_id = "@bob:red"
user = UserID.from_string(user_id)
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": user_id,
"room_id": room_id,
"content": {"membership": Membership.LEAVE},
})
self.datastore.get_latest_events_in_room.return_value = (
defer.succeed([])
)
self.datastore.get_current_state.return_value = {}
self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
def annotate(_):
ctx = Mock()
ctx.current_state = {
(EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
membership=Membership.JOIN
),
}
ctx.prev_state_events = []
return defer.succeed(ctx)
self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[
(EventTypes.Member, "@bob:red")
]
return defer.succeed(True)
self.auth.add_auth_events.side_effect = add_auth
room_handler = self.room_member_handler
event, context = yield room_handler._create_new_client_event(
builder
)
leave_signal_observer = Mock()
self.distributor.observe("user_left_room", leave_signal_observer)
# Actual invocation
yield room_handler.change_membership(event, context)
self.federation.handle_new_event.assert_called_once_with(
event, destinations=set(['red'])
)
self.datastore.persist_event.assert_called_once_with(
event, context=context
)
self.notifier.on_new_room_event.assert_called_once_with(
event, 1, 1, extra_users=[user]
)
leave_signal_observer.assert_called_with(
user=user, room_id=room_id
)
class RoomCreationTest(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hostname = "red"
hs = yield setup_test_homeserver(
self.hostname,
datastore=NonCallableMock(spec_set=[
"store_room",
"snapshot_room",
"persist_event",
"get_joined_hosts_for_room",
]),
http_client=NonCallableMock(spec_set=[]),
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
handlers=NonCallableMock(spec_set=[
"room_creation_handler",
"message_handler",
]),
auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
)
self.federation = NonCallableMock(spec_set=[
"handle_new_event",
])
self.handlers = hs.get_handlers()
self.handlers.room_creation_handler = RoomCreationHandler(hs)
self.room_creation_handler = self.handlers.room_creation_handler
self.message_handler = self.handlers.message_handler
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@defer.inlineCallbacks
def test_room_creation(self):
user_id = "@foo:red"
room_id = "!bobs_room:red"
config = {"visibility": "private"}
yield self.room_creation_handler.create_room(
user_id=user_id,
room_id=room_id,
config=config,
)
self.assertTrue(self.message_handler.create_and_send_event.called)
event_dicts = [
e[0][0]
for e in self.message_handler.create_and_send_event.call_args_list
]
self.assertTrue(len(event_dicts) > 3)
self.assertDictContainsSubset(
{
"type": EventTypes.Create,
"sender": user_id,
"room_id": room_id,
},
event_dicts[0]
)
self.assertEqual(user_id, event_dicts[0]["content"]["creator"])
self.assertDictContainsSubset(
{
"type": EventTypes.Member,
"sender": user_id,
"room_id": room_id,
"state_key": user_id,
},
event_dicts[1]
)
self.assertEqual(
Membership.JOIN,
event_dicts[1]["content"]["membership"]
)

View file

@ -12,12 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
from synapse.config._base import ConfigError
from tests import unittest
from twisted.internet import defer
from tests.utils import setup_test_homeserver
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.server import HomeServer
from synapse.storage.appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore
)
@ -26,7 +27,6 @@ import json
import os
import yaml
from mock import Mock
from tests.utils import SQLiteMemoryDbPool, MockClock
class ApplicationServiceStoreTestCase(unittest.TestCase):
@ -41,9 +41,16 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self.as_token = "token1"
self.as_url = "some_url"
self._add_appservice(self.as_token, self.as_url, "some_hs_token", "bob")
self._add_appservice("token2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "some_url", "some_hs_token", "bob")
self.as_id = "as1"
self._add_appservice(
self.as_token,
self.as_id,
self.as_url,
"some_hs_token",
"bob"
)
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts
self.store = ApplicationServiceStore(hs)
@ -55,9 +62,9 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
except:
pass
def _add_appservice(self, as_token, url, hs_token, sender):
def _add_appservice(self, as_token, id, url, hs_token, sender):
as_yaml = dict(url=url, as_token=as_token, hs_token=hs_token,
sender_localpart=sender, namespaces={})
id=id, sender_localpart=sender, namespaces={})
# use the token as the filename
with open(as_token, 'w') as outfile:
outfile.write(yaml.dump(as_yaml))
@ -74,6 +81,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self.as_token
)
self.assertEquals(stored_service.token, self.as_token)
self.assertEquals(stored_service.id, self.as_id)
self.assertEquals(stored_service.url, self.as_url)
self.assertEquals(
stored_service.namespaces[ApplicationService.NS_ALIASES],
@ -110,34 +118,34 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
{
"token": "token1",
"url": "https://matrix-as.org",
"id": "token1"
"id": "id_1"
},
{
"token": "alpha_tok",
"url": "https://alpha.com",
"id": "alpha_tok"
"id": "id_alpha"
},
{
"token": "beta_tok",
"url": "https://beta.com",
"id": "beta_tok"
"id": "id_beta"
},
{
"token": "delta_tok",
"url": "https://delta.com",
"id": "delta_tok"
"token": "gamma_tok",
"url": "https://gamma.com",
"id": "id_gamma"
},
]
for s in self.as_list:
yield self._add_service(s["url"], s["token"])
yield self._add_service(s["url"], s["token"], s["id"])
self.as_yaml_files = []
self.store = TestTransactionStore(hs)
def _add_service(self, url, as_token):
def _add_service(self, url, as_token, id):
as_yaml = dict(url=url, as_token=as_token, hs_token="something",
sender_localpart="a_sender", namespaces={})
id=id, sender_localpart="a_sender", namespaces={})
# use the token as the filename
with open(as_token, 'w') as outfile:
outfile.write(yaml.dump(as_yaml))
@ -405,3 +413,64 @@ class TestTransactionStore(ApplicationServiceTransactionStore,
def __init__(self, hs):
super(TestTransactionStore, self).__init__(hs)
class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
def _write_config(self, suffix, **kwargs):
vals = {
"id": "id" + suffix,
"url": "url" + suffix,
"as_token": "as_token" + suffix,
"hs_token": "hs_token" + suffix,
"sender_localpart": "sender_localpart" + suffix,
"namespaces": {},
}
vals.update(kwargs)
_, path = tempfile.mkstemp(prefix="as_config")
with open(path, "w") as f:
f.write(yaml.dump(vals))
return path
@defer.inlineCallbacks
def test_unique_works(self):
f1 = self._write_config(suffix="1")
f2 = self._write_config(suffix="2")
config = Mock(app_service_config_files=[f1, f2])
hs = yield setup_test_homeserver(config=config)
ApplicationServiceStore(hs)
@defer.inlineCallbacks
def test_duplicate_ids(self):
f1 = self._write_config(id="id", suffix="1")
f2 = self._write_config(id="id", suffix="2")
config = Mock(app_service_config_files=[f1, f2])
hs = yield setup_test_homeserver(config=config)
with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(hs)
e = cm.exception
self.assertIn(f1, e.message)
self.assertIn(f2, e.message)
self.assertIn("id", e.message)
@defer.inlineCallbacks
def test_duplicate_as_tokens(self):
f1 = self._write_config(as_token="as_token", suffix="1")
f2 = self._write_config(as_token="as_token", suffix="2")
config = Mock(app_service_config_files=[f1, f2])
hs = yield setup_test_homeserver(config=config)
with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(hs)
e = cm.exception
self.assertIn(f1, e.message)
self.assertIn(f2, e.message)
self.assertIn("as_token", e.message)