Add missing type hints to synapse.api. (#11109)

* Convert UserPresenceState to attrs.
* Remove args/kwargs from error classes and explicitly pass msg/errorcode.
This commit is contained in:
Patrick Cloke 2021-10-18 15:01:10 -04:00 committed by GitHub
parent cc33d9eee2
commit 3ab55d43bd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 84 additions and 99 deletions

1
changelog.d/11109.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to `synapse.api` module.

View file

@ -100,6 +100,9 @@ files =
tests/util/test_itertools.py, tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py tests/util/test_stream_change_cache.py
[mypy-synapse.api.*]
disallow_untyped_defs = True
[mypy-synapse.events.*] [mypy-synapse.events.*]
disallow_untyped_defs = True disallow_untyped_defs = True

View file

@ -245,7 +245,7 @@ class Auth:
async def validate_appservice_can_control_user_id( async def validate_appservice_can_control_user_id(
self, app_service: ApplicationService, user_id: str self, app_service: ApplicationService, user_id: str
): ) -> None:
"""Validates that the app service is allowed to control """Validates that the app service is allowed to control
the given user. the given user.
@ -618,5 +618,13 @@ class Auth:
% (user_id, room_id), % (user_id, room_id),
) )
async def check_auth_blocking(self, *args, **kwargs) -> None: async def check_auth_blocking(
await self._auth_blocking.check_auth_blocking(*args, **kwargs) self,
user_id: Optional[str] = None,
threepid: Optional[dict] = None,
user_type: Optional[str] = None,
requester: Optional[Requester] = None,
) -> None:
await self._auth_blocking.check_auth_blocking(
user_id=user_id, threepid=threepid, user_type=user_type, requester=requester
)

View file

@ -18,7 +18,7 @@
import logging import logging
import typing import typing
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from twisted.web import http from twisted.web import http
@ -143,7 +143,7 @@ class SynapseError(CodeMessageException):
super().__init__(code, msg) super().__init__(code, msg)
self.errcode = errcode self.errcode = errcode
def error_dict(self): def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode) return cs_error(self.msg, self.errcode)
@ -175,7 +175,7 @@ class ProxiedRequestError(SynapseError):
else: else:
self._additional_fields = dict(additional_fields) self._additional_fields = dict(additional_fields)
def error_dict(self): def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, **self._additional_fields) return cs_error(self.msg, self.errcode, **self._additional_fields)
@ -196,7 +196,7 @@ class ConsentNotGivenError(SynapseError):
) )
self._consent_uri = consent_uri self._consent_uri = consent_uri
def error_dict(self): def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri) return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri)
@ -262,14 +262,10 @@ class InteractiveAuthIncompleteError(Exception):
class UnrecognizedRequestError(SynapseError): class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make""" """An error indicating we don't understand the request you're trying to make"""
def __init__(self, *args, **kwargs): def __init__(
if "errcode" not in kwargs: self, msg: str = "Unrecognized request", errcode: str = Codes.UNRECOGNIZED
kwargs["errcode"] = Codes.UNRECOGNIZED ):
if len(args) == 0: super().__init__(400, msg, errcode)
message = "Unrecognized request"
else:
message = args[0]
super().__init__(400, message, **kwargs)
class NotFoundError(SynapseError): class NotFoundError(SynapseError):
@ -284,10 +280,8 @@ class AuthError(SynapseError):
other poorly-defined times. other poorly-defined times.
""" """
def __init__(self, *args, **kwargs): def __init__(self, code: int, msg: str, errcode: str = Codes.FORBIDDEN):
if "errcode" not in kwargs: super().__init__(code, msg, errcode)
kwargs["errcode"] = Codes.FORBIDDEN
super().__init__(*args, **kwargs)
class InvalidClientCredentialsError(SynapseError): class InvalidClientCredentialsError(SynapseError):
@ -321,7 +315,7 @@ class InvalidClientTokenError(InvalidClientCredentialsError):
super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN") super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
self._soft_logout = soft_logout self._soft_logout = soft_logout
def error_dict(self): def error_dict(self) -> "JsonDict":
d = super().error_dict() d = super().error_dict()
d["soft_logout"] = self._soft_logout d["soft_logout"] = self._soft_logout
return d return d
@ -345,7 +339,7 @@ class ResourceLimitError(SynapseError):
self.limit_type = limit_type self.limit_type = limit_type
super().__init__(code, msg, errcode=errcode) super().__init__(code, msg, errcode=errcode)
def error_dict(self): def error_dict(self) -> "JsonDict":
return cs_error( return cs_error(
self.msg, self.msg,
self.errcode, self.errcode,
@ -357,32 +351,17 @@ class ResourceLimitError(SynapseError):
class EventSizeError(SynapseError): class EventSizeError(SynapseError):
"""An error raised when an event is too big.""" """An error raised when an event is too big."""
def __init__(self, *args, **kwargs): def __init__(self, msg: str):
if "errcode" not in kwargs: super().__init__(413, msg, Codes.TOO_LARGE)
kwargs["errcode"] = Codes.TOO_LARGE
super().__init__(413, *args, **kwargs)
class EventStreamError(SynapseError):
"""An error raised when there a problem with the event stream."""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.BAD_PAGINATION
super().__init__(*args, **kwargs)
class LoginError(SynapseError): class LoginError(SynapseError):
"""An error raised when there was a problem logging in.""" """An error raised when there was a problem logging in."""
pass
class StoreError(SynapseError): class StoreError(SynapseError):
"""An error raised when there was a problem storing some data.""" """An error raised when there was a problem storing some data."""
pass
class InvalidCaptchaError(SynapseError): class InvalidCaptchaError(SynapseError):
def __init__( def __init__(
@ -395,7 +374,7 @@ class InvalidCaptchaError(SynapseError):
super().__init__(code, msg, errcode) super().__init__(code, msg, errcode)
self.error_url = error_url self.error_url = error_url
def error_dict(self): def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, error_url=self.error_url) return cs_error(self.msg, self.errcode, error_url=self.error_url)
@ -412,7 +391,7 @@ class LimitExceededError(SynapseError):
super().__init__(code, msg, errcode) super().__init__(code, msg, errcode)
self.retry_after_ms = retry_after_ms self.retry_after_ms = retry_after_ms
def error_dict(self): def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)
@ -443,10 +422,8 @@ class UnsupportedRoomVersionError(SynapseError):
class ThreepidValidationError(SynapseError): class ThreepidValidationError(SynapseError):
"""An error raised when there was a problem authorising an event.""" """An error raised when there was a problem authorising an event."""
def __init__(self, *args, **kwargs): def __init__(self, msg: str, errcode: str = Codes.FORBIDDEN):
if "errcode" not in kwargs: super().__init__(400, msg, errcode)
kwargs["errcode"] = Codes.FORBIDDEN
super().__init__(*args, **kwargs)
class IncompatibleRoomVersionError(SynapseError): class IncompatibleRoomVersionError(SynapseError):
@ -466,7 +443,7 @@ class IncompatibleRoomVersionError(SynapseError):
self._room_version = room_version self._room_version = room_version
def error_dict(self): def error_dict(self) -> "JsonDict":
return cs_error(self.msg, self.errcode, room_version=self._room_version) return cs_error(self.msg, self.errcode, room_version=self._room_version)
@ -494,7 +471,7 @@ class RequestSendFailed(RuntimeError):
errors (like programming errors). errors (like programming errors).
""" """
def __init__(self, inner_exception, can_retry): def __init__(self, inner_exception: BaseException, can_retry: bool):
super().__init__( super().__init__(
"Failed to send request: %s: %s" "Failed to send request: %s: %s"
% (type(inner_exception).__name__, inner_exception) % (type(inner_exception).__name__, inner_exception)
@ -503,7 +480,7 @@ class RequestSendFailed(RuntimeError):
self.can_retry = can_retry self.can_retry = can_retry
def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs): def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict":
"""Utility method for constructing an error response for client-server """Utility method for constructing an error response for client-server
interactions. interactions.
@ -551,7 +528,7 @@ class FederationError(RuntimeError):
msg = "%s %s: %s" % (level, code, reason) msg = "%s %s: %s" % (level, code, reason)
super().__init__(msg) super().__init__(msg)
def get_dict(self): def get_dict(self) -> "JsonDict":
return { return {
"level": self.level, "level": self.level,
"code": self.code, "code": self.code,
@ -580,7 +557,7 @@ class HttpResponseException(CodeMessageException):
super().__init__(code, msg) super().__init__(code, msg)
self.response = response self.response = response
def to_synapse_error(self): def to_synapse_error(self) -> SynapseError:
"""Make a SynapseError based on an HTTPResponseException """Make a SynapseError based on an HTTPResponseException
This is useful when a proxied request has failed, and we need to This is useful when a proxied request has failed, and we need to

View file

@ -231,24 +231,24 @@ class FilterCollection:
def include_redundant_members(self) -> bool: def include_redundant_members(self) -> bool:
return self._room_state_filter.include_redundant_members() return self._room_state_filter.include_redundant_members()
def filter_presence(self, events): def filter_presence(
self, events: Iterable[UserPresenceState]
) -> List[UserPresenceState]:
return self._presence_filter.filter(events) return self._presence_filter.filter(events)
def filter_account_data(self, events): def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._account_data.filter(events) return self._account_data.filter(events)
def filter_room_state(self, events): def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
return self._room_state_filter.filter(self._room_filter.filter(events)) return self._room_state_filter.filter(self._room_filter.filter(events))
def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]:
return self._room_timeline_filter.filter(self._room_filter.filter(events)) return self._room_timeline_filter.filter(self._room_filter.filter(events))
def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._room_ephemeral_filter.filter(self._room_filter.filter(events)) return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
def filter_room_account_data( def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
self, events: Iterable[FilterEvent]
) -> List[FilterEvent]:
return self._room_account_data.filter(self._room_filter.filter(events)) return self._room_account_data.filter(self._room_filter.filter(events))
def blocks_all_presence(self) -> bool: def blocks_all_presence(self) -> bool:
@ -309,7 +309,7 @@ class Filter:
# except for presence which actually gets passed around as its own # except for presence which actually gets passed around as its own
# namedtuple type. # namedtuple type.
if isinstance(event, UserPresenceState): if isinstance(event, UserPresenceState):
sender = event.user_id sender: Optional[str] = event.user_id
room_id = None room_id = None
ev_type = "m.presence" ev_type = "m.presence"
contains_url = False contains_url = False

View file

@ -12,49 +12,48 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import namedtuple from typing import Any, Optional
import attr
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.types import JsonDict
class UserPresenceState( @attr.s(slots=True, frozen=True, auto_attribs=True)
namedtuple( class UserPresenceState:
"UserPresenceState",
(
"user_id",
"state",
"last_active_ts",
"last_federation_update_ts",
"last_user_sync_ts",
"status_msg",
"currently_active",
),
)
):
"""Represents the current presence state of the user. """Represents the current presence state of the user.
user_id (str) user_id
last_active (int): Time in msec that the user last interacted with server. last_active: Time in msec that the user last interacted with server.
last_federation_update (int): Time in msec since either a) we sent a presence last_federation_update: Time in msec since either a) we sent a presence
update to other servers or b) we received a presence update, depending update to other servers or b) we received a presence update, depending
on if is a local user or not. on if is a local user or not.
last_user_sync (int): Time in msec that the user last *completed* a sync last_user_sync: Time in msec that the user last *completed* a sync
(or event stream). (or event stream).
status_msg (str): User set status message. status_msg: User set status message.
""" """
def as_dict(self): user_id: str
return dict(self._asdict()) state: str
last_active_ts: int
last_federation_update_ts: int
last_user_sync_ts: int
status_msg: Optional[str]
currently_active: bool
def as_dict(self) -> JsonDict:
return attr.asdict(self)
@staticmethod @staticmethod
def from_dict(d): def from_dict(d: JsonDict) -> "UserPresenceState":
return UserPresenceState(**d) return UserPresenceState(**d)
def copy_and_replace(self, **kwargs): def copy_and_replace(self, **kwargs: Any) -> "UserPresenceState":
return self._replace(**kwargs) return attr.evolve(self, **kwargs)
@classmethod @classmethod
def default(cls, user_id): def default(cls, user_id: str) -> "UserPresenceState":
"""Returns a default presence state.""" """Returns a default presence state."""
return cls( return cls(
user_id=user_id, user_id=user_id,

View file

@ -161,7 +161,7 @@ class Ratelimiter:
return allowed, time_allowed return allowed, time_allowed
def _prune_message_counts(self, time_now_s: float): def _prune_message_counts(self, time_now_s: float) -> None:
"""Remove message count entries that have not exceeded their defined """Remove message count entries that have not exceeded their defined
rate_hz limit rate_hz limit
@ -190,7 +190,7 @@ class Ratelimiter:
update: bool = True, update: bool = True,
n_actions: int = 1, n_actions: int = 1,
_time_now_s: Optional[float] = None, _time_now_s: Optional[float] = None,
): ) -> None:
"""Checks if an action can be performed. If not, raises a LimitExceededError """Checks if an action can be performed. If not, raises a LimitExceededError
Checks if the user has ratelimiting disabled in the database by looking Checks if the user has ratelimiting disabled in the database by looking

View file

@ -19,6 +19,7 @@ from hashlib import sha256
from urllib.parse import urlencode from urllib.parse import urlencode
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config.homeserver import HomeServerConfig
SYNAPSE_CLIENT_API_PREFIX = "/_synapse/client" SYNAPSE_CLIENT_API_PREFIX = "/_synapse/client"
CLIENT_API_PREFIX = "/_matrix/client" CLIENT_API_PREFIX = "/_matrix/client"
@ -34,11 +35,7 @@ LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
class ConsentURIBuilder: class ConsentURIBuilder:
def __init__(self, hs_config): def __init__(self, hs_config: HomeServerConfig):
"""
Args:
hs_config (synapse.config.homeserver.HomeServerConfig):
"""
if hs_config.key.form_secret is None: if hs_config.key.form_secret is None:
raise ConfigError("form_secret not set in config") raise ConfigError("form_secret not set in config")
if hs_config.server.public_baseurl is None: if hs_config.server.public_baseurl is None:
@ -47,15 +44,15 @@ class ConsentURIBuilder:
self._hmac_secret = hs_config.key.form_secret.encode("utf-8") self._hmac_secret = hs_config.key.form_secret.encode("utf-8")
self._public_baseurl = hs_config.server.public_baseurl self._public_baseurl = hs_config.server.public_baseurl
def build_user_consent_uri(self, user_id): def build_user_consent_uri(self, user_id: str) -> str:
"""Build a URI which we can give to the user to do their privacy """Build a URI which we can give to the user to do their privacy
policy consent policy consent
Args: Args:
user_id (str): mxid or username of user user_id: mxid or username of user
Returns Returns
(str) the URI where the user can do consent The URI where the user can do consent
""" """
mac = hmac.new( mac = hmac.new(
key=self._hmac_secret, msg=user_id.encode("ascii"), digestmod=sha256 key=self._hmac_secret, msg=user_id.encode("ascii"), digestmod=sha256

View file

@ -1489,7 +1489,7 @@ def format_user_presence_state(
The "user_id" is optional so that this function can be used to format presence The "user_id" is optional so that this function can be used to format presence
updates for client /sync responses and for federation /send requests. updates for client /sync responses and for federation /send requests.
""" """
content = {"presence": state.state} content: JsonDict = {"presence": state.state}
if include_user_id: if include_user_id:
content["user_id"] = state.user_id content["user_id"] = state.user_id
if state.last_active_ts: if state.last_active_ts:

View file

@ -2237,7 +2237,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
# accident. # accident.
row = {"client_secret": None, "validated_at": None} row = {"client_secret": None, "validated_at": None}
else: else:
raise ThreepidValidationError(400, "Unknown session_id") raise ThreepidValidationError("Unknown session_id")
retrieved_client_secret = row["client_secret"] retrieved_client_secret = row["client_secret"]
validated_at = row["validated_at"] validated_at = row["validated_at"]
@ -2252,14 +2252,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
if not row: if not row:
raise ThreepidValidationError( raise ThreepidValidationError(
400, "Validation token not found or has expired" "Validation token not found or has expired"
) )
expires = row["expires"] expires = row["expires"]
next_link = row["next_link"] next_link = row["next_link"]
if retrieved_client_secret != client_secret: if retrieved_client_secret != client_secret:
raise ThreepidValidationError( raise ThreepidValidationError(
400, "This client_secret does not match the provided session_id" "This client_secret does not match the provided session_id"
) )
# If the session is already validated, no need to revalidate # If the session is already validated, no need to revalidate
@ -2268,7 +2268,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
if expires <= current_ts: if expires <= current_ts:
raise ThreepidValidationError( raise ThreepidValidationError(
400, "This token has expired. Please request a new one" "This token has expired. Please request a new one"
) )
# Looks good. Validate the session # Looks good. Validate the session