Generics for ObservableDeferred (#10491)

Now that `Deferred` is a generic class, let's update `ObeservableDeferred` to
follow suit.
This commit is contained in:
Richard van der Hoff 2021-07-28 20:55:50 +01:00 committed by GitHub
parent d0b294ad97
commit 858363d0b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 15 additions and 9 deletions

1
changelog.d/10491.misc Normal file
View file

@ -0,0 +1 @@
Improve type annotations for `ObservableDeferred`.

View file

@ -111,8 +111,9 @@ class _NotifierUserStream:
self.last_notified_token = current_token self.last_notified_token = current_token
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
with PreserveLoggingContext(): self.notify_deferred: ObservableDeferred[StreamToken] = ObservableDeferred(
self.notify_deferred = ObservableDeferred(defer.Deferred()) defer.Deferred()
)
def notify( def notify(
self, self,

View file

@ -170,7 +170,9 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
end_item = queue[-1] end_item = queue[-1]
else: else:
# need to make a new queue item # need to make a new queue item
deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
defer.Deferred(), consumeErrors=True
)
end_item = _EventPersistQueueItem( end_item = _EventPersistQueueItem(
events_and_contexts=[], events_and_contexts=[],

View file

@ -23,6 +23,7 @@ from typing import (
Awaitable, Awaitable,
Callable, Callable,
Dict, Dict,
Generic,
Hashable, Hashable,
Iterable, Iterable,
List, List,
@ -39,6 +40,7 @@ from twisted.internet import defer
from twisted.internet.defer import CancelledError from twisted.internet.defer import CancelledError
from twisted.internet.interfaces import IReactorTime from twisted.internet.interfaces import IReactorTime
from twisted.python import failure from twisted.python import failure
from twisted.python.failure import Failure
from synapse.logging.context import ( from synapse.logging.context import (
PreserveLoggingContext, PreserveLoggingContext,
@ -52,7 +54,7 @@ logger = logging.getLogger(__name__)
_T = TypeVar("_T") _T = TypeVar("_T")
class ObservableDeferred: class ObservableDeferred(Generic[_T]):
"""Wraps a deferred object so that we can add observer deferreds. These """Wraps a deferred object so that we can add observer deferreds. These
observer deferreds do not affect the callback chain of the original observer deferreds do not affect the callback chain of the original
deferred. deferred.
@ -70,7 +72,7 @@ class ObservableDeferred:
__slots__ = ["_deferred", "_observers", "_result"] __slots__ = ["_deferred", "_observers", "_result"]
def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False): def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None) object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", set()) object.__setattr__(self, "_observers", set())
@ -115,7 +117,7 @@ class ObservableDeferred:
deferred.addCallbacks(callback, errback) deferred.addCallbacks(callback, errback)
def observe(self) -> defer.Deferred: def observe(self) -> "defer.Deferred[_T]":
"""Observe the underlying deferred. """Observe the underlying deferred.
This returns a brand new deferred that is resolved when the underlying This returns a brand new deferred that is resolved when the underlying
@ -123,7 +125,7 @@ class ObservableDeferred:
effect the underlying deferred. effect the underlying deferred.
""" """
if not self._result: if not self._result:
d: "defer.Deferred[Any]" = defer.Deferred() d: "defer.Deferred[_T]" = defer.Deferred()
def remove(r): def remove(r):
self._observers.discard(d) self._observers.discard(d)
@ -137,7 +139,7 @@ class ObservableDeferred:
success, res = self._result success, res = self._result
return defer.succeed(res) if success else defer.fail(res) return defer.succeed(res) if success else defer.fail(res)
def observers(self) -> List[defer.Deferred]: def observers(self) -> "List[defer.Deferred[_T]]":
return self._observers return self._observers
def has_called(self) -> bool: def has_called(self) -> bool:
@ -146,7 +148,7 @@ class ObservableDeferred:
def has_succeeded(self) -> bool: def has_succeeded(self) -> bool:
return self._result is not None and self._result[0] is True return self._result is not None and self._result[0] is True
def get_result(self) -> Any: def get_result(self) -> Union[_T, Failure]:
return self._result[1] return self._result[1]
def __getattr__(self, name: str) -> Any: def __getattr__(self, name: str) -> Any: