Add type hints for HTTP and email pushers. (#8880)

This commit is contained in:
Patrick Cloke 2020-12-07 09:59:38 -05:00 committed by GitHub
parent 02e588856a
commit 92d87c6882
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 168 additions and 101 deletions

1
changelog.d/8880.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to push module.

View file

@ -55,7 +55,10 @@ files =
synapse/metrics, synapse/metrics,
synapse/module_api, synapse/module_api,
synapse/notifier.py, synapse/notifier.py,
synapse/push/emailpusher.py,
synapse/push/httppusher.py,
synapse/push/mailer.py, synapse/push/mailer.py,
synapse/push/pusher.py,
synapse/push/pusherpool.py, synapse/push/pusherpool.py,
synapse/push/push_rule_evaluator.py, synapse/push/push_rule_evaluator.py,
synapse/replication, synapse/replication,

View file

@ -13,6 +13,56 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
from typing import TYPE_CHECKING, Any, Dict, Optional
from synapse.types import RoomStreamToken
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
class Pusher(metaclass=abc.ABCMeta):
def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
self.hs = hs
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pusher_id = pusherdict["id"]
self.user_id = pusherdict["user_name"]
self.app_id = pusherdict["app_id"]
self.pushkey = pusherdict["pushkey"]
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
# should honour this rather than just looking for anything higher
# because of potential out-of-order event serialisation. This starts
# off as None though as we don't know any better.
self.max_stream_ordering = None # type: Optional[int]
@abc.abstractmethod
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
raise NotImplementedError()
@abc.abstractmethod
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
raise NotImplementedError()
@abc.abstractmethod
def on_started(self, have_notifs: bool) -> None:
"""Called when this pusher has been started.
Args:
should_check_for_notifs: Whether we should immediately
check for push to send. Set to False only if it's known there
is nothing to send
"""
raise NotImplementedError()
@abc.abstractmethod
def on_stop(self) -> None:
raise NotImplementedError()
class PusherConfigException(Exception): class PusherConfigException(Exception):
"""An error occurred when creating a pusher.""" """An error occurred when creating a pusher."""

View file

@ -14,12 +14,19 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from twisted.internet.base import DelayedCall
from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher
from synapse.push.mailer import Mailer
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# The amount of time we always wait before ever emailing about a notification # The amount of time we always wait before ever emailing about a notification
@ -46,7 +53,7 @@ THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000
INCLUDE_ALL_UNREAD_NOTIFS = False INCLUDE_ALL_UNREAD_NOTIFS = False
class EmailPusher: class EmailPusher(Pusher):
""" """
A pusher that sends email notifications about events (approximately) A pusher that sends email notifications about events (approximately)
when they happen. when they happen.
@ -54,37 +61,31 @@ class EmailPusher:
factor out the common parts factor out the common parts
""" """
def __init__(self, hs, pusherdict, mailer): def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any], mailer: Mailer):
self.hs = hs super().__init__(hs, pusherdict)
self.mailer = mailer self.mailer = mailer
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pusher_id = pusherdict["id"]
self.user_id = pusherdict["user_name"]
self.app_id = pusherdict["app_id"]
self.email = pusherdict["pushkey"] self.email = pusherdict["pushkey"]
self.last_stream_ordering = pusherdict["last_stream_ordering"] self.last_stream_ordering = pusherdict["last_stream_ordering"]
self.timed_call = None self.timed_call = None # type: Optional[DelayedCall]
self.throttle_params = None self.throttle_params = {} # type: Dict[str, Dict[str, int]]
self._inited = False
# See httppusher
self.max_stream_ordering = None
self._is_processing = False self._is_processing = False
def on_started(self, should_check_for_notifs): def on_started(self, should_check_for_notifs: bool) -> None:
"""Called when this pusher has been started. """Called when this pusher has been started.
Args: Args:
should_check_for_notifs (bool): Whether we should immediately should_check_for_notifs: Whether we should immediately
check for push to send. Set to False only if it's known there check for push to send. Set to False only if it's known there
is nothing to send is nothing to send
""" """
if should_check_for_notifs and self.mailer is not None: if should_check_for_notifs and self.mailer is not None:
self._start_processing() self._start_processing()
def on_stop(self): def on_stop(self) -> None:
if self.timed_call: if self.timed_call:
try: try:
self.timed_call.cancel() self.timed_call.cancel()
@ -92,7 +93,7 @@ class EmailPusher:
pass pass
self.timed_call = None self.timed_call = None
def on_new_notifications(self, max_token: RoomStreamToken): def on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock # We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector # component. This is safe to do as long as we *always* ignore the vector
# clock components. # clock components.
@ -106,23 +107,23 @@ class EmailPusher:
self.max_stream_ordering = max_stream_ordering self.max_stream_ordering = max_stream_ordering
self._start_processing() self._start_processing()
def on_new_receipts(self, min_stream_id, max_stream_id): def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# We could wake up and cancel the timer but there tend to be quite a # We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the # lot of read receipts so it's probably less work to just let the
# timer fire # timer fire
pass pass
def on_timer(self): def on_timer(self) -> None:
self.timed_call = None self.timed_call = None
self._start_processing() self._start_processing()
def _start_processing(self): def _start_processing(self) -> None:
if self._is_processing: if self._is_processing:
return return
run_as_background_process("emailpush.process", self._process) run_as_background_process("emailpush.process", self._process)
def _pause_processing(self): def _pause_processing(self) -> None:
"""Used by tests to temporarily pause processing of events. """Used by tests to temporarily pause processing of events.
Asserts that its not currently processing. Asserts that its not currently processing.
@ -130,25 +131,26 @@ class EmailPusher:
assert not self._is_processing assert not self._is_processing
self._is_processing = True self._is_processing = True
def _resume_processing(self): def _resume_processing(self) -> None:
"""Used by tests to resume processing of events after pausing. """Used by tests to resume processing of events after pausing.
""" """
assert self._is_processing assert self._is_processing
self._is_processing = False self._is_processing = False
self._start_processing() self._start_processing()
async def _process(self): async def _process(self) -> None:
# we should never get here if we are already processing # we should never get here if we are already processing
assert not self._is_processing assert not self._is_processing
try: try:
self._is_processing = True self._is_processing = True
if self.throttle_params is None: if not self._inited:
# this is our first loop: load up the throttle params # this is our first loop: load up the throttle params
self.throttle_params = await self.store.get_throttle_params_by_room( self.throttle_params = await self.store.get_throttle_params_by_room(
self.pusher_id self.pusher_id
) )
self._inited = True
# if the max ordering changes while we're running _unsafe_process, # if the max ordering changes while we're running _unsafe_process,
# call it again, and so on until we've caught up. # call it again, and so on until we've caught up.
@ -163,17 +165,19 @@ class EmailPusher:
finally: finally:
self._is_processing = False self._is_processing = False
async def _unsafe_process(self): async def _unsafe_process(self) -> None:
""" """
Main logic of the push loop without the wrapper function that sets Main logic of the push loop without the wrapper function that sets
up logging, measures and guards against multiple instances of it up logging, measures and guards against multiple instances of it
being run. being run.
""" """
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
fn = self.store.get_unread_push_actions_for_user_in_range_for_email assert self.max_stream_ordering is not None
unprocessed = await fn(self.user_id, start, self.max_stream_ordering) unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
self.user_id, start, self.max_stream_ordering
)
soonest_due_at = None soonest_due_at = None # type: Optional[int]
if not unprocessed: if not unprocessed:
await self.save_last_stream_ordering_and_success(self.max_stream_ordering) await self.save_last_stream_ordering_and_success(self.max_stream_ordering)
@ -230,7 +234,9 @@ class EmailPusher:
self.seconds_until(soonest_due_at), self.on_timer self.seconds_until(soonest_due_at), self.on_timer
) )
async def save_last_stream_ordering_and_success(self, last_stream_ordering): async def save_last_stream_ordering_and_success(
self, last_stream_ordering: Optional[int]
) -> None:
if last_stream_ordering is None: if last_stream_ordering is None:
# This happens if we haven't yet processed anything # This happens if we haven't yet processed anything
return return
@ -248,28 +254,30 @@ class EmailPusher:
# lets just stop and return. # lets just stop and return.
self.on_stop() self.on_stop()
def seconds_until(self, ts_msec): def seconds_until(self, ts_msec: int) -> float:
secs = (ts_msec - self.clock.time_msec()) / 1000 secs = (ts_msec - self.clock.time_msec()) / 1000
return max(secs, 0) return max(secs, 0)
def get_room_throttle_ms(self, room_id): def get_room_throttle_ms(self, room_id: str) -> int:
if room_id in self.throttle_params: if room_id in self.throttle_params:
return self.throttle_params[room_id]["throttle_ms"] return self.throttle_params[room_id]["throttle_ms"]
else: else:
return 0 return 0
def get_room_last_sent_ts(self, room_id): def get_room_last_sent_ts(self, room_id: str) -> int:
if room_id in self.throttle_params: if room_id in self.throttle_params:
return self.throttle_params[room_id]["last_sent_ts"] return self.throttle_params[room_id]["last_sent_ts"]
else: else:
return 0 return 0
def room_ready_to_notify_at(self, room_id): def room_ready_to_notify_at(self, room_id: str) -> int:
""" """
Determines whether throttling should prevent us from sending an email Determines whether throttling should prevent us from sending an email
for the given room for the given room
Returns: The timestamp when we are next allowed to send an email notif
for this room Returns:
The timestamp when we are next allowed to send an email notif
for this room
""" """
last_sent_ts = self.get_room_last_sent_ts(room_id) last_sent_ts = self.get_room_last_sent_ts(room_id)
throttle_ms = self.get_room_throttle_ms(room_id) throttle_ms = self.get_room_throttle_ms(room_id)
@ -277,7 +285,9 @@ class EmailPusher:
may_send_at = last_sent_ts + throttle_ms may_send_at = last_sent_ts + throttle_ms
return may_send_at return may_send_at
async def sent_notif_update_throttle(self, room_id, notified_push_action): async def sent_notif_update_throttle(
self, room_id: str, notified_push_action: dict
) -> None:
# We have sent a notification, so update the throttle accordingly. # We have sent a notification, so update the throttle accordingly.
# If the event that triggered the notif happened more than # If the event that triggered the notif happened more than
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a # THROTTLE_RESET_AFTER_MS after the previous one that triggered a
@ -315,7 +325,7 @@ class EmailPusher:
self.pusher_id, room_id, self.throttle_params[room_id] self.pusher_id, room_id, self.throttle_params[room_id]
) )
async def send_notification(self, push_actions, reason): async def send_notification(self, push_actions: List[dict], reason: dict) -> None:
logger.info("Sending notif email for user %r", self.user_id) logger.info("Sending notif email for user %r", self.user_id)
await self.mailer.send_notification_mail( await self.mailer.send_notification_mail(

View file

@ -15,19 +15,24 @@
# limitations under the License. # limitations under the License.
import logging import logging
import urllib.parse import urllib.parse
from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.logging import opentracing from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException from synapse.push import Pusher, PusherConfigException
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from . import push_rule_evaluator, push_tools from . import push_rule_evaluator, push_tools
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
http_push_processed_counter = Counter( http_push_processed_counter = Counter(
@ -51,24 +56,18 @@ http_badges_failed_counter = Counter(
) )
class HttpPusher: class HttpPusher(Pusher):
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
MAX_BACKOFF_SEC = 60 * 60 MAX_BACKOFF_SEC = 60 * 60
# This one's in ms because we compare it against the clock # This one's in ms because we compare it against the clock
GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000 GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
def __init__(self, hs, pusherdict): def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
self.hs = hs super().__init__(hs, pusherdict)
self.store = self.hs.get_datastore()
self.storage = self.hs.get_storage() self.storage = self.hs.get_storage()
self.clock = self.hs.get_clock()
self.state_handler = self.hs.get_state_handler()
self.user_id = pusherdict["user_name"]
self.app_id = pusherdict["app_id"]
self.app_display_name = pusherdict["app_display_name"] self.app_display_name = pusherdict["app_display_name"]
self.device_display_name = pusherdict["device_display_name"] self.device_display_name = pusherdict["device_display_name"]
self.pushkey = pusherdict["pushkey"]
self.pushkey_ts = pusherdict["ts"] self.pushkey_ts = pusherdict["ts"]
self.data = pusherdict["data"] self.data = pusherdict["data"]
self.last_stream_ordering = pusherdict["last_stream_ordering"] self.last_stream_ordering = pusherdict["last_stream_ordering"]
@ -78,13 +77,6 @@ class HttpPusher:
self._is_processing = False self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
# should honour this rather than just looking for anything higher
# because of potential out-of-order event serialisation. This starts
# off as None though as we don't know any better.
self.max_stream_ordering = None
if "data" not in pusherdict: if "data" not in pusherdict:
raise PusherConfigException("No 'data' key for HTTP pusher") raise PusherConfigException("No 'data' key for HTTP pusher")
self.data = pusherdict["data"] self.data = pusherdict["data"]
@ -119,18 +111,18 @@ class HttpPusher:
self.data_minus_url.update(self.data) self.data_minus_url.update(self.data)
del self.data_minus_url["url"] del self.data_minus_url["url"]
def on_started(self, should_check_for_notifs): def on_started(self, should_check_for_notifs: bool) -> None:
"""Called when this pusher has been started. """Called when this pusher has been started.
Args: Args:
should_check_for_notifs (bool): Whether we should immediately should_check_for_notifs: Whether we should immediately
check for push to send. Set to False only if it's known there check for push to send. Set to False only if it's known there
is nothing to send is nothing to send
""" """
if should_check_for_notifs: if should_check_for_notifs:
self._start_processing() self._start_processing()
def on_new_notifications(self, max_token: RoomStreamToken): def on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock # We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector # component. This is safe to do as long as we *always* ignore the vector
# clock components. # clock components.
@ -141,14 +133,14 @@ class HttpPusher:
) )
self._start_processing() self._start_processing()
def on_new_receipts(self, min_stream_id, max_stream_id): def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# Note that the min here shouldn't be relied upon to be accurate. # Note that the min here shouldn't be relied upon to be accurate.
# We could check the receipts are actually m.read receipts here, # We could check the receipts are actually m.read receipts here,
# but currently that's the only type of receipt anyway... # but currently that's the only type of receipt anyway...
run_as_background_process("http_pusher.on_new_receipts", self._update_badge) run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
async def _update_badge(self): async def _update_badge(self) -> None:
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. perhaps we can remove it. # to be largely redundant. perhaps we can remove it.
badge = await push_tools.get_badge_count( badge = await push_tools.get_badge_count(
@ -158,10 +150,10 @@ class HttpPusher:
) )
await self._send_badge(badge) await self._send_badge(badge)
def on_timer(self): def on_timer(self) -> None:
self._start_processing() self._start_processing()
def on_stop(self): def on_stop(self) -> None:
if self.timed_call: if self.timed_call:
try: try:
self.timed_call.cancel() self.timed_call.cancel()
@ -169,13 +161,13 @@ class HttpPusher:
pass pass
self.timed_call = None self.timed_call = None
def _start_processing(self): def _start_processing(self) -> None:
if self._is_processing: if self._is_processing:
return return
run_as_background_process("httppush.process", self._process) run_as_background_process("httppush.process", self._process)
async def _process(self): async def _process(self) -> None:
# we should never get here if we are already processing # we should never get here if we are already processing
assert not self._is_processing assert not self._is_processing
@ -194,7 +186,7 @@ class HttpPusher:
finally: finally:
self._is_processing = False self._is_processing = False
async def _unsafe_process(self): async def _unsafe_process(self) -> None:
""" """
Looks for unset notifications and dispatch them, in order Looks for unset notifications and dispatch them, in order
Never call this directly: use _process which will only allow this to Never call this directly: use _process which will only allow this to
@ -202,6 +194,7 @@ class HttpPusher:
""" """
fn = self.store.get_unread_push_actions_for_user_in_range_for_http fn = self.store.get_unread_push_actions_for_user_in_range_for_http
assert self.max_stream_ordering is not None
unprocessed = await fn( unprocessed = await fn(
self.user_id, self.last_stream_ordering, self.max_stream_ordering self.user_id, self.last_stream_ordering, self.max_stream_ordering
) )
@ -271,17 +264,12 @@ class HttpPusher:
) )
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"] self.last_stream_ordering = push_action["stream_ordering"]
pusher_still_exists = await self.store.update_pusher_last_stream_ordering( await self.store.update_pusher_last_stream_ordering(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_id, self.user_id,
self.last_stream_ordering, self.last_stream_ordering,
) )
if not pusher_still_exists:
# The pusher has been deleted while we were processing, so
# lets just stop and return.
self.on_stop()
return
self.failing_since = None self.failing_since = None
await self.store.update_pusher_failing_since( await self.store.update_pusher_failing_since(
@ -297,7 +285,7 @@ class HttpPusher:
) )
break break
async def _process_one(self, push_action): async def _process_one(self, push_action: dict) -> bool:
if "notify" not in push_action["actions"]: if "notify" not in push_action["actions"]:
return True return True
@ -328,7 +316,9 @@ class HttpPusher:
await self.hs.remove_pusher(self.app_id, pk, self.user_id) await self.hs.remove_pusher(self.app_id, pk, self.user_id)
return True return True
async def _build_notification_dict(self, event, tweaks, badge): async def _build_notification_dict(
self, event: EventBase, tweaks: Dict[str, bool], badge: int
) -> Dict[str, Any]:
priority = "low" priority = "low"
if ( if (
event.type == EventTypes.Encrypted event.type == EventTypes.Encrypted
@ -358,9 +348,7 @@ class HttpPusher:
} }
return d return d
ctx = await push_tools.get_context_for_event( ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id)
self.storage, self.state_handler, event, self.user_id
)
d = { d = {
"notification": { "notification": {
@ -400,7 +388,9 @@ class HttpPusher:
return d return d
async def dispatch_push(self, event, tweaks, badge): async def dispatch_push(
self, event: EventBase, tweaks: Dict[str, bool], badge: int
) -> Union[bool, Iterable[str]]:
notification_dict = await self._build_notification_dict(event, tweaks, badge) notification_dict = await self._build_notification_dict(event, tweaks, badge)
if not notification_dict: if not notification_dict:
return [] return []

View file

@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict
from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage from synapse.storage import Storage
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
@ -46,7 +49,9 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
return badge return badge
async def get_context_for_event(storage: Storage, state_handler, ev, user_id): async def get_context_for_event(
storage: Storage, ev: EventBase, user_id: str
) -> Dict[str, str]:
ctx = {} ctx = {}
room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id) room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id)

View file

@ -14,25 +14,31 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from synapse.push import Pusher
from synapse.push.emailpusher import EmailPusher from synapse.push.emailpusher import EmailPusher
from synapse.push.httppusher import HttpPusher
from synapse.push.mailer import Mailer from synapse.push.mailer import Mailer
from .httppusher import HttpPusher if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PusherFactory: class PusherFactory:
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.config = hs.config self.config = hs.config
self.pusher_types = {"http": HttpPusher} self.pusher_types = {
"http": HttpPusher
} # type: Dict[str, Callable[[HomeServer, dict], Pusher]]
logger.info("email enable notifs: %r", hs.config.email_enable_notifs) logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs: if hs.config.email_enable_notifs:
self.mailers = {} # app_name -> Mailer self.mailers = {} # type: Dict[str, Mailer]
self._notif_template_html = hs.config.email_notif_template_html self._notif_template_html = hs.config.email_notif_template_html
self._notif_template_text = hs.config.email_notif_template_text self._notif_template_text = hs.config.email_notif_template_text
@ -41,7 +47,7 @@ class PusherFactory:
logger.info("defined email pusher type") logger.info("defined email pusher type")
def create_pusher(self, pusherdict): def create_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
kind = pusherdict["kind"] kind = pusherdict["kind"]
f = self.pusher_types.get(kind, None) f = self.pusher_types.get(kind, None)
if not f: if not f:
@ -49,7 +55,9 @@ class PusherFactory:
logger.debug("creating %s pusher for %r", kind, pusherdict) logger.debug("creating %s pusher for %r", kind, pusherdict)
return f(self.hs, pusherdict) return f(self.hs, pusherdict)
def _create_email_pusher(self, _hs, pusherdict): def _create_email_pusher(
self, _hs: "HomeServer", pusherdict: Dict[str, Any]
) -> EmailPusher:
app_name = self._app_name_from_pusherdict(pusherdict) app_name = self._app_name_from_pusherdict(pusherdict)
mailer = self.mailers.get(app_name) mailer = self.mailers.get(app_name)
if not mailer: if not mailer:
@ -62,7 +70,7 @@ class PusherFactory:
self.mailers[app_name] = mailer self.mailers[app_name] = mailer
return EmailPusher(self.hs, pusherdict, mailer) return EmailPusher(self.hs, pusherdict, mailer)
def _app_name_from_pusherdict(self, pusherdict): def _app_name_from_pusherdict(self, pusherdict: Dict[str, Any]) -> str:
data = pusherdict["data"] data = pusherdict["data"]
if isinstance(data, dict): if isinstance(data, dict):

View file

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Dict, Union from typing import TYPE_CHECKING, Any, Dict, Optional
from prometheus_client import Gauge from prometheus_client import Gauge
@ -23,9 +23,7 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process, run_as_background_process,
wrap_as_background_process, wrap_as_background_process,
) )
from synapse.push import PusherConfigException from synapse.push import Pusher, PusherConfigException
from synapse.push.emailpusher import EmailPusher
from synapse.push.httppusher import HttpPusher
from synapse.push.pusher import PusherFactory from synapse.push.pusher import PusherFactory
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute
@ -77,7 +75,7 @@ class PusherPool:
self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering() self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
# map from user id to app_id:pushkey to pusher # map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]] self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
def start(self): def start(self):
"""Starts the pushers off in a background process. """Starts the pushers off in a background process.
@ -99,11 +97,11 @@ class PusherPool:
lang, lang,
data, data,
profile_tag="", profile_tag="",
): ) -> Optional[Pusher]:
"""Creates a new pusher and adds it to the pool """Creates a new pusher and adds it to the pool
Returns: Returns:
EmailPusher|HttpPusher The newly created pusher.
""" """
time_now_msec = self.clock.time_msec() time_now_msec = self.clock.time_msec()
@ -267,17 +265,19 @@ class PusherPool:
except Exception: except Exception:
logger.exception("Exception in pusher on_new_receipts") logger.exception("Exception in pusher on_new_receipts")
async def start_pusher_by_id(self, app_id, pushkey, user_id): async def start_pusher_by_id(
self, app_id: str, pushkey: str, user_id: str
) -> Optional[Pusher]:
"""Look up the details for the given pusher, and start it """Look up the details for the given pusher, and start it
Returns: Returns:
EmailPusher|HttpPusher|None: The pusher started, if any The pusher started, if any
""" """
if not self._should_start_pushers: if not self._should_start_pushers:
return return None
if not self._pusher_shard_config.should_handle(self._instance_name, user_id): if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return return None
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
@ -303,19 +303,19 @@ class PusherPool:
logger.info("Started pushers") logger.info("Started pushers")
async def _start_pusher(self, pusherdict): async def _start_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
"""Start the given pusher """Start the given pusher
Args: Args:
pusherdict (dict): dict with the values pulled from the db table pusherdict: dict with the values pulled from the db table
Returns: Returns:
EmailPusher|HttpPusher The newly created pusher or None.
""" """
if not self._pusher_shard_config.should_handle( if not self._pusher_shard_config.should_handle(
self._instance_name, pusherdict["user_name"] self._instance_name, pusherdict["user_name"]
): ):
return return None
try: try:
p = self.pusher_factory.create_pusher(pusherdict) p = self.pusher_factory.create_pusher(pusherdict)
@ -328,15 +328,15 @@ class PusherPool:
pusherdict.get("pushkey"), pusherdict.get("pushkey"),
e, e,
) )
return return None
except Exception: except Exception:
logger.exception( logger.exception(
"Couldn't start pusher id %i: caught Exception", pusherdict["id"], "Couldn't start pusher id %i: caught Exception", pusherdict["id"],
) )
return return None
if not p: if not p:
return return None
appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"]) appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])