Add type hints to groups code. (#9393)

This commit is contained in:
Patrick Cloke 2021-02-17 08:41:47 -05:00 committed by GitHub
parent e1071fd625
commit d2f0ec12d5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 341 additions and 124 deletions

View file

@ -1 +1 @@
Assert a maximum length for the `client_secret` parameter for spec compliance. Assert a maximum length for some parameters for spec compliance.

1
changelog.d/9393.bugfix Normal file
View file

@ -0,0 +1 @@
Assert a maximum length for some parameters for spec compliance.

View file

@ -23,6 +23,7 @@ files =
synapse/events/validator.py, synapse/events/validator.py,
synapse/events/spamcheck.py, synapse/events/spamcheck.py,
synapse/federation, synapse/federation,
synapse/groups,
synapse/handlers, synapse/handlers,
synapse/http/client.py, synapse/http/client.py,
synapse/http/federation/matrix_federation_agent.py, synapse/http/federation/matrix_federation_agent.py,

View file

@ -27,6 +27,11 @@ MAX_ALIAS_LENGTH = 255
# the maximum length for a user id is 255 characters # the maximum length for a user id is 255 characters
MAX_USERID_LENGTH = 255 MAX_USERID_LENGTH = 255
# The maximum length for a group id is 255 characters
MAX_GROUPID_LENGTH = 255
MAX_GROUP_CATEGORYID_LENGTH = 255
MAX_GROUP_ROLEID_LENGTH = 255
class Membership: class Membership:

View file

@ -21,6 +21,7 @@ import re
from typing import Optional, Tuple, Type from typing import Optional, Tuple, Type
import synapse import synapse
from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH
from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.api.urls import ( from synapse.api.urls import (
@ -1118,7 +1119,17 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
raise SynapseError(403, "requester_user_id doesn't match origin") raise SynapseError(403, "requester_user_id doesn't match origin")
if category_id == "": if category_id == "":
raise SynapseError(400, "category_id cannot be empty string") raise SynapseError(
400, "category_id cannot be empty string", Codes.INVALID_PARAM
)
if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
raise SynapseError(
400,
"category_id may not be longer than %s characters"
% (MAX_GROUP_CATEGORYID_LENGTH,),
Codes.INVALID_PARAM,
)
resp = await self.handler.update_group_summary_room( resp = await self.handler.update_group_summary_room(
group_id, group_id,
@ -1184,6 +1195,14 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
if category_id == "": if category_id == "":
raise SynapseError(400, "category_id cannot be empty string") raise SynapseError(400, "category_id cannot be empty string")
if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
raise SynapseError(
400,
"category_id may not be longer than %s characters"
% (MAX_GROUP_CATEGORYID_LENGTH,),
Codes.INVALID_PARAM,
)
resp = await self.handler.upsert_group_category( resp = await self.handler.upsert_group_category(
group_id, requester_user_id, category_id, content group_id, requester_user_id, category_id, content
) )
@ -1240,7 +1259,17 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
raise SynapseError(403, "requester_user_id doesn't match origin") raise SynapseError(403, "requester_user_id doesn't match origin")
if role_id == "": if role_id == "":
raise SynapseError(400, "role_id cannot be empty string") raise SynapseError(
400, "role_id cannot be empty string", Codes.INVALID_PARAM
)
if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
raise SynapseError(
400,
"role_id may not be longer than %s characters"
% (MAX_GROUP_ROLEID_LENGTH,),
Codes.INVALID_PARAM,
)
resp = await self.handler.update_group_role( resp = await self.handler.update_group_role(
group_id, requester_user_id, role_id, content group_id, requester_user_id, role_id, content
@ -1285,6 +1314,14 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
if role_id == "": if role_id == "":
raise SynapseError(400, "role_id cannot be empty string") raise SynapseError(400, "role_id cannot be empty string")
if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
raise SynapseError(
400,
"role_id may not be longer than %s characters"
% (MAX_GROUP_ROLEID_LENGTH,),
Codes.INVALID_PARAM,
)
resp = await self.handler.update_group_summary_user( resp = await self.handler.update_group_summary_user(
group_id, group_id,
requester_user_id, requester_user_id,

View file

@ -37,13 +37,16 @@ An attestation is a signed blob of json that looks like:
import logging import logging
import random import random
from typing import Tuple from typing import TYPE_CHECKING, Optional, Tuple
from signedjson.sign import sign_json from signedjson.sign import sign_json
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id from synapse.types import JsonDict, get_domain_from_id
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -63,15 +66,19 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
class GroupAttestationSigning: class GroupAttestationSigning:
"""Creates and verifies group attestations.""" """Creates and verifies group attestations."""
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.server_name = hs.hostname self.server_name = hs.hostname
self.signing_key = hs.signing_key self.signing_key = hs.signing_key
async def verify_attestation( async def verify_attestation(
self, attestation, group_id, user_id, server_name=None self,
): attestation: JsonDict,
group_id: str,
user_id: str,
server_name: Optional[str] = None,
) -> None:
"""Verifies that the given attestation matches the given parameters. """Verifies that the given attestation matches the given parameters.
An optional server_name can be supplied to explicitly set which server's An optional server_name can be supplied to explicitly set which server's
@ -100,16 +107,18 @@ class GroupAttestationSigning:
if valid_until_ms < now: if valid_until_ms < now:
raise SynapseError(400, "Attestation expired") raise SynapseError(400, "Attestation expired")
assert server_name is not None
await self.keyring.verify_json_for_server( await self.keyring.verify_json_for_server(
server_name, attestation, now, "Group attestation" server_name, attestation, now, "Group attestation"
) )
def create_attestation(self, group_id, user_id): def create_attestation(self, group_id: str, user_id: str) -> JsonDict:
"""Create an attestation for the group_id and user_id with default """Create an attestation for the group_id and user_id with default
validity length. validity length.
""" """
validity_period = DEFAULT_ATTESTATION_LENGTH_MS validity_period = DEFAULT_ATTESTATION_LENGTH_MS * random.uniform(
validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER) *DEFAULT_ATTESTATION_JITTER
)
valid_until_ms = int(self.clock.time_msec() + validity_period) valid_until_ms = int(self.clock.time_msec() + validity_period)
return sign_json( return sign_json(
@ -126,7 +135,7 @@ class GroupAttestationSigning:
class GroupAttestionRenewer: class GroupAttestionRenewer:
"""Responsible for sending and receiving attestation updates.""" """Responsible for sending and receiving attestation updates."""
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.assestations = hs.get_groups_attestation_signing() self.assestations = hs.get_groups_attestation_signing()
@ -139,7 +148,9 @@ class GroupAttestionRenewer:
self._start_renew_attestations, 30 * 60 * 1000 self._start_renew_attestations, 30 * 60 * 1000
) )
async def on_renew_attestation(self, group_id, user_id, content): async def on_renew_attestation(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""When a remote updates an attestation""" """When a remote updates an attestation"""
attestation = content["attestation"] attestation = content["attestation"]
@ -154,10 +165,10 @@ class GroupAttestionRenewer:
return {} return {}
def _start_renew_attestations(self): def _start_renew_attestations(self) -> None:
return run_as_background_process("renew_attestations", self._renew_attestations) return run_as_background_process("renew_attestations", self._renew_attestations)
async def _renew_attestations(self): async def _renew_attestations(self) -> None:
"""Called periodically to check if we need to update any of our attestations""" """Called periodically to check if we need to update any of our attestations"""
now = self.clock.time_msec() now = self.clock.time_msec()
@ -166,7 +177,7 @@ class GroupAttestionRenewer:
now + UPDATE_ATTESTATION_TIME_MS now + UPDATE_ATTESTATION_TIME_MS
) )
async def _renew_attestation(group_user: Tuple[str, str]): async def _renew_attestation(group_user: Tuple[str, str]) -> None:
group_id, user_id = group_user group_id, user_id = group_user
try: try:
if not self.is_mine_id(group_id): if not self.is_mine_id(group_id):

View file

@ -16,12 +16,17 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id from synapse.types import GroupID, JsonDict, RoomID, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,7 +44,7 @@ MAX_LONG_DESC_LEN = 10000
class GroupsServerWorkerHandler: class GroupsServerWorkerHandler:
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.room_list_handler = hs.get_room_list_handler() self.room_list_handler = hs.get_room_list_handler()
@ -54,16 +59,21 @@ class GroupsServerWorkerHandler:
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
async def check_group_is_ours( async def check_group_is_ours(
self, group_id, requester_user_id, and_exists=False, and_is_admin=None self,
): group_id: str,
requester_user_id: str,
and_exists: bool = False,
and_is_admin: Optional[str] = None,
) -> Optional[dict]:
"""Check that the group is ours, and optionally if it exists. """Check that the group is ours, and optionally if it exists.
If group does exist then return group. If group does exist then return group.
Args: Args:
group_id (str) group_id: The group ID to check.
and_exists (bool): whether to also check if group exists requester_user_id: The user ID of the requester.
and_is_admin (str): whether to also check if given str is a user_id and_exists: whether to also check if group exists
and_is_admin: whether to also check if given str is a user_id
that is an admin that is an admin
""" """
if not self.is_mine_id(group_id): if not self.is_mine_id(group_id):
@ -86,7 +96,9 @@ class GroupsServerWorkerHandler:
return group return group
async def get_group_summary(self, group_id, requester_user_id): async def get_group_summary(
self, group_id: str, requester_user_id: str
) -> JsonDict:
"""Get the summary for a group as seen by requester_user_id. """Get the summary for a group as seen by requester_user_id.
The group summary consists of the profile of the room, and a curated The group summary consists of the profile of the room, and a curated
@ -119,6 +131,8 @@ class GroupsServerWorkerHandler:
entry = await self.room_list_handler.generate_room_entry( entry = await self.room_list_handler.generate_room_entry(
room_id, len(joined_users), with_alias=False, allow_private=True room_id, len(joined_users), with_alias=False, allow_private=True
) )
if entry is None:
continue
entry = dict(entry) # so we don't change what's cached entry = dict(entry) # so we don't change what's cached
entry.pop("room_id", None) entry.pop("room_id", None)
@ -126,22 +140,22 @@ class GroupsServerWorkerHandler:
rooms.sort(key=lambda e: e.get("order", 0)) rooms.sort(key=lambda e: e.get("order", 0))
for entry in users: for user in users:
user_id = entry["user_id"] user_id = user["user_id"]
if not self.is_mine_id(requester_user_id): if not self.is_mine_id(requester_user_id):
attestation = await self.store.get_remote_attestation(group_id, user_id) attestation = await self.store.get_remote_attestation(group_id, user_id)
if not attestation: if not attestation:
continue continue
entry["attestation"] = attestation user["attestation"] = attestation
else: else:
entry["attestation"] = self.attestations.create_attestation( user["attestation"] = self.attestations.create_attestation(
group_id, user_id group_id, user_id
) )
user_profile = await self.profile_handler.get_profile_from_cache(user_id) user_profile = await self.profile_handler.get_profile_from_cache(user_id)
entry.update(user_profile) user.update(user_profile)
users.sort(key=lambda e: e.get("order", 0)) users.sort(key=lambda e: e.get("order", 0))
@ -164,40 +178,43 @@ class GroupsServerWorkerHandler:
"user": membership_info, "user": membership_info,
} }
async def get_group_categories(self, group_id, requester_user_id): async def get_group_categories(
self, group_id: str, requester_user_id: str
) -> JsonDict:
"""Get all categories in a group (as seen by user)""" """Get all categories in a group (as seen by user)"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
categories = await self.store.get_group_categories(group_id=group_id) categories = await self.store.get_group_categories(group_id=group_id)
return {"categories": categories} return {"categories": categories}
async def get_group_category(self, group_id, requester_user_id, category_id): async def get_group_category(
self, group_id: str, requester_user_id: str, category_id: str
) -> JsonDict:
"""Get a specific category in a group (as seen by user)""" """Get a specific category in a group (as seen by user)"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
res = await self.store.get_group_category( return await self.store.get_group_category(
group_id=group_id, category_id=category_id group_id=group_id, category_id=category_id
) )
logger.info("group %s", res) async def get_group_roles(self, group_id: str, requester_user_id: str) -> JsonDict:
return res
async def get_group_roles(self, group_id, requester_user_id):
"""Get all roles in a group (as seen by user)""" """Get all roles in a group (as seen by user)"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
roles = await self.store.get_group_roles(group_id=group_id) roles = await self.store.get_group_roles(group_id=group_id)
return {"roles": roles} return {"roles": roles}
async def get_group_role(self, group_id, requester_user_id, role_id): async def get_group_role(
self, group_id: str, requester_user_id: str, role_id: str
) -> JsonDict:
"""Get a specific role in a group (as seen by user)""" """Get a specific role in a group (as seen by user)"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
res = await self.store.get_group_role(group_id=group_id, role_id=role_id) return await self.store.get_group_role(group_id=group_id, role_id=role_id)
return res
async def get_group_profile(self, group_id, requester_user_id): async def get_group_profile(
self, group_id: str, requester_user_id: str
) -> JsonDict:
"""Get the group profile as seen by requester_user_id""" """Get the group profile as seen by requester_user_id"""
await self.check_group_is_ours(group_id, requester_user_id) await self.check_group_is_ours(group_id, requester_user_id)
@ -219,7 +236,9 @@ class GroupsServerWorkerHandler:
else: else:
raise SynapseError(404, "Unknown group") raise SynapseError(404, "Unknown group")
async def get_users_in_group(self, group_id, requester_user_id): async def get_users_in_group(
self, group_id: str, requester_user_id: str
) -> JsonDict:
"""Get the users in group as seen by requester_user_id. """Get the users in group as seen by requester_user_id.
The ordering is arbitrary at the moment The ordering is arbitrary at the moment
@ -268,7 +287,9 @@ class GroupsServerWorkerHandler:
return {"chunk": chunk, "total_user_count_estimate": len(user_results)} return {"chunk": chunk, "total_user_count_estimate": len(user_results)}
async def get_invited_users_in_group(self, group_id, requester_user_id): async def get_invited_users_in_group(
self, group_id: str, requester_user_id: str
) -> JsonDict:
"""Get the users that have been invited to a group as seen by requester_user_id. """Get the users that have been invited to a group as seen by requester_user_id.
The ordering is arbitrary at the moment The ordering is arbitrary at the moment
@ -298,7 +319,9 @@ class GroupsServerWorkerHandler:
return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
async def get_rooms_in_group(self, group_id, requester_user_id): async def get_rooms_in_group(
self, group_id: str, requester_user_id: str
) -> JsonDict:
"""Get the rooms in group as seen by requester_user_id """Get the rooms in group as seen by requester_user_id
This returns rooms in order of decreasing number of joined users This returns rooms in order of decreasing number of joined users
@ -336,15 +359,20 @@ class GroupsServerWorkerHandler:
class GroupsServerHandler(GroupsServerWorkerHandler): class GroupsServerHandler(GroupsServerWorkerHandler):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
# Ensure attestations get renewed # Ensure attestations get renewed
hs.get_groups_attestation_renewer() hs.get_groups_attestation_renewer()
async def update_group_summary_room( async def update_group_summary_room(
self, group_id, requester_user_id, room_id, category_id, content self,
): group_id: str,
requester_user_id: str,
room_id: str,
category_id: str,
content: JsonDict,
) -> JsonDict:
"""Add/update a room to the group summary""" """Add/update a room to the group summary"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@ -367,8 +395,8 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def delete_group_summary_room( async def delete_group_summary_room(
self, group_id, requester_user_id, room_id, category_id self, group_id: str, requester_user_id: str, room_id: str, category_id: str
): ) -> JsonDict:
"""Remove a room from the summary""" """Remove a room from the summary"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@ -380,7 +408,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def set_group_join_policy(self, group_id, requester_user_id, content): async def set_group_join_policy(
self, group_id: str, requester_user_id: str, content: JsonDict
) -> JsonDict:
"""Sets the group join policy. """Sets the group join policy.
Currently supported policies are: Currently supported policies are:
@ -400,8 +430,8 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def update_group_category( async def update_group_category(
self, group_id, requester_user_id, category_id, content self, group_id: str, requester_user_id: str, category_id: str, content: JsonDict
): ) -> JsonDict:
"""Add/Update a group category""" """Add/Update a group category"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@ -419,7 +449,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def delete_group_category(self, group_id, requester_user_id, category_id): async def delete_group_category(
self, group_id: str, requester_user_id: str, category_id: str
) -> JsonDict:
"""Delete a group category""" """Delete a group category"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@ -431,7 +463,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def update_group_role(self, group_id, requester_user_id, role_id, content): async def update_group_role(
self, group_id: str, requester_user_id: str, role_id: str, content: JsonDict
) -> JsonDict:
"""Add/update a role in a group""" """Add/update a role in a group"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@ -447,7 +481,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def delete_group_role(self, group_id, requester_user_id, role_id): async def delete_group_role(
self, group_id: str, requester_user_id: str, role_id: str
) -> JsonDict:
"""Remove role from group""" """Remove role from group"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@ -458,8 +494,13 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def update_group_summary_user( async def update_group_summary_user(
self, group_id, requester_user_id, user_id, role_id, content self,
): group_id: str,
requester_user_id: str,
user_id: str,
role_id: str,
content: JsonDict,
) -> JsonDict:
"""Add/update a users entry in the group summary""" """Add/update a users entry in the group summary"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@ -480,8 +521,8 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def delete_group_summary_user( async def delete_group_summary_user(
self, group_id, requester_user_id, user_id, role_id self, group_id: str, requester_user_id: str, user_id: str, role_id: str
): ) -> JsonDict:
"""Remove a user from the group summary""" """Remove a user from the group summary"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@ -493,7 +534,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def update_group_profile(self, group_id, requester_user_id, content): async def update_group_profile(
self, group_id: str, requester_user_id: str, content: JsonDict
) -> None:
"""Update the group profile""" """Update the group profile"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@ -524,7 +567,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
await self.store.update_group_profile(group_id, profile) await self.store.update_group_profile(group_id, profile)
async def add_room_to_group(self, group_id, requester_user_id, room_id, content): async def add_room_to_group(
self, group_id: str, requester_user_id: str, room_id: str, content: JsonDict
) -> JsonDict:
"""Add room to group""" """Add room to group"""
RoomID.from_string(room_id) # Ensure valid room id RoomID.from_string(room_id) # Ensure valid room id
@ -539,8 +584,13 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def update_room_in_group( async def update_room_in_group(
self, group_id, requester_user_id, room_id, config_key, content self,
): group_id: str,
requester_user_id: str,
room_id: str,
config_key: str,
content: JsonDict,
) -> JsonDict:
"""Update room in group""" """Update room in group"""
RoomID.from_string(room_id) # Ensure valid room id RoomID.from_string(room_id) # Ensure valid room id
@ -559,7 +609,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def remove_room_from_group(self, group_id, requester_user_id, room_id): async def remove_room_from_group(
self, group_id: str, requester_user_id: str, room_id: str
) -> JsonDict:
"""Remove room from group""" """Remove room from group"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@ -569,12 +621,16 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def invite_to_group(self, group_id, user_id, requester_user_id, content): async def invite_to_group(
self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
) -> JsonDict:
"""Invite user to group""" """Invite user to group"""
group = await self.check_group_is_ours( group = await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
) )
if not group:
raise SynapseError(400, "Group does not exist", errcode=Codes.BAD_STATE)
# TODO: Check if user knocked # TODO: Check if user knocked
@ -597,6 +653,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if self.hs.is_mine_id(user_id): if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler() groups_local = self.hs.get_groups_local_handler()
assert isinstance(
groups_local, GroupsLocalHandler
), "Workers cannot invites users to groups."
res = await groups_local.on_invite(group_id, user_id, content) res = await groups_local.on_invite(group_id, user_id, content)
local_attestation = None local_attestation = None
else: else:
@ -632,6 +691,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
local_attestation=local_attestation, local_attestation=local_attestation,
remote_attestation=remote_attestation, remote_attestation=remote_attestation,
) )
return {"state": "join"}
elif res["state"] == "invite": elif res["state"] == "invite":
await self.store.add_group_invite(group_id, user_id) await self.store.add_group_invite(group_id, user_id)
return {"state": "invite"} return {"state": "invite"}
@ -640,13 +700,17 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
else: else:
raise SynapseError(502, "Unknown state returned by HS") raise SynapseError(502, "Unknown state returned by HS")
async def _add_user(self, group_id, user_id, content): async def _add_user(
self, group_id: str, user_id: str, content: JsonDict
) -> Optional[JsonDict]:
"""Add a user to a group based on a content dict. """Add a user to a group based on a content dict.
See accept_invite, join_group. See accept_invite, join_group.
""" """
if not self.hs.is_mine_id(user_id): if not self.hs.is_mine_id(user_id):
local_attestation = self.attestations.create_attestation(group_id, user_id) local_attestation = self.attestations.create_attestation(
group_id, user_id
) # type: Optional[JsonDict]
remote_attestation = content["attestation"] remote_attestation = content["attestation"]
@ -670,7 +734,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return local_attestation return local_attestation
async def accept_invite(self, group_id, requester_user_id, content): async def accept_invite(
self, group_id: str, requester_user_id: str, content: JsonDict
) -> JsonDict:
"""User tries to accept an invite to the group. """User tries to accept an invite to the group.
This is different from them asking to join, and so should error if no This is different from them asking to join, and so should error if no
@ -689,7 +755,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {"state": "join", "attestation": local_attestation} return {"state": "join", "attestation": local_attestation}
async def join_group(self, group_id, requester_user_id, content): async def join_group(
self, group_id: str, requester_user_id: str, content: JsonDict
) -> JsonDict:
"""User tries to join the group. """User tries to join the group.
This will error if the group requires an invite/knock to join This will error if the group requires an invite/knock to join
@ -698,6 +766,8 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
group_info = await self.check_group_is_ours( group_info = await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True group_id, requester_user_id, and_exists=True
) )
if not group_info:
raise SynapseError(404, "Group does not exist", errcode=Codes.NOT_FOUND)
if group_info["join_policy"] != "open": if group_info["join_policy"] != "open":
raise SynapseError(403, "Group is not publicly joinable") raise SynapseError(403, "Group is not publicly joinable")
@ -705,25 +775,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {"state": "join", "attestation": local_attestation} return {"state": "join", "attestation": local_attestation}
async def knock(self, group_id, requester_user_id, content):
"""A user requests becoming a member of the group"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
raise NotImplementedError()
async def accept_knock(self, group_id, requester_user_id, content):
"""Accept a users knock to the room.
Errors if the user hasn't knocked, rather than inviting them.
"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
raise NotImplementedError()
async def remove_user_from_group( async def remove_user_from_group(
self, group_id, user_id, requester_user_id, content self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
): ) -> JsonDict:
"""Remove a user from the group; either a user is leaving or an admin """Remove a user from the group; either a user is leaving or an admin
kicked them. kicked them.
""" """
@ -745,6 +799,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if is_kick: if is_kick:
if self.hs.is_mine_id(user_id): if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler() groups_local = self.hs.get_groups_local_handler()
assert isinstance(
groups_local, GroupsLocalHandler
), "Workers cannot remove users from groups."
await groups_local.user_removed_from_group(group_id, user_id, {}) await groups_local.user_removed_from_group(group_id, user_id, {})
else: else:
await self.transport_client.remove_user_from_group_notification( await self.transport_client.remove_user_from_group_notification(
@ -761,14 +818,15 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def create_group(self, group_id, requester_user_id, content): async def create_group(
group = await self.check_group_is_ours(group_id, requester_user_id) self, group_id: str, requester_user_id: str, content: JsonDict
) -> JsonDict:
logger.info("Attempting to create group with ID: %r", group_id) logger.info("Attempting to create group with ID: %r", group_id)
# parsing the id into a GroupID validates it. # parsing the id into a GroupID validates it.
group_id_obj = GroupID.from_string(group_id) group_id_obj = GroupID.from_string(group_id)
group = await self.check_group_is_ours(group_id, requester_user_id)
if group: if group:
raise SynapseError(400, "Group already exists") raise SynapseError(400, "Group already exists")
@ -813,7 +871,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
local_attestation = self.attestations.create_attestation( local_attestation = self.attestations.create_attestation(
group_id, requester_user_id group_id, requester_user_id
) ) # type: Optional[JsonDict]
else: else:
local_attestation = None local_attestation = None
remote_attestation = None remote_attestation = None
@ -836,15 +894,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {"group_id": group_id} return {"group_id": group_id}
async def delete_group(self, group_id, requester_user_id): async def delete_group(self, group_id: str, requester_user_id: str) -> None:
"""Deletes a group, kicking out all current members. """Deletes a group, kicking out all current members.
Only group admins or server admins can call this request Only group admins or server admins can call this request
Args: Args:
group_id (str) group_id: The group ID to delete.
request_user_id (str) requester_user_id: The user requesting to delete the group.
""" """
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
@ -867,6 +924,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
async def _kick_user_from_group(user_id): async def _kick_user_from_group(user_id):
if self.hs.is_mine_id(user_id): if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler() groups_local = self.hs.get_groups_local_handler()
assert isinstance(
groups_local, GroupsLocalHandler
), "Workers cannot kick users from groups."
await groups_local.user_removed_from_group(group_id, user_id, {}) await groups_local.user_removed_from_group(group_id, user_id, {})
else: else:
await self.transport_client.remove_user_from_group_notification( await self.transport_client.remove_user_from_group_notification(
@ -898,7 +958,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
await self.store.delete_group(group_id) await self.store.delete_group(group_id)
def _parse_join_policy_from_contents(content): def _parse_join_policy_from_contents(content: JsonDict) -> Optional[str]:
"""Given a content for a request, return the specified join policy or None""" """Given a content for a request, return the specified join policy or None"""
join_policy_dict = content.get("m.join_policy") join_policy_dict = content.get("m.join_policy")
@ -908,7 +968,7 @@ def _parse_join_policy_from_contents(content):
return None return None
def _parse_join_policy_dict(join_policy_dict): def _parse_join_policy_dict(join_policy_dict: JsonDict) -> str:
"""Given a dict for the "m.join_policy" config return the join policy specified""" """Given a dict for the "m.join_policy" config return the join policy specified"""
join_policy_type = join_policy_dict.get("type") join_policy_type = join_policy_dict.get("type")
if not join_policy_type: if not join_policy_type:
@ -919,7 +979,7 @@ def _parse_join_policy_dict(join_policy_dict):
return join_policy_type return join_policy_type
def _parse_visibility_from_contents(content): def _parse_visibility_from_contents(content: JsonDict) -> bool:
"""Given a content for a request parse out whether the entity should be """Given a content for a request parse out whether the entity should be
public or not public or not
""" """
@ -933,7 +993,7 @@ def _parse_visibility_from_contents(content):
return is_public return is_public
def _parse_visibility_dict(visibility): def _parse_visibility_dict(visibility: JsonDict) -> bool:
"""Given a dict for the "m.visibility" config return if the entity should """Given a dict for the "m.visibility" config return if the entity should
be public or not be public or not
""" """

View file

@ -16,11 +16,16 @@
import logging import logging
from functools import wraps from functools import wraps
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Optional, Tuple
from twisted.web.http import Request from twisted.web.http import Request
from synapse.api.errors import SynapseError from synapse.api.constants import (
MAX_GROUP_CATEGORYID_LENGTH,
MAX_GROUP_ROLEID_LENGTH,
MAX_GROUPID_LENGTH,
)
from synapse.api.errors import Codes, SynapseError
from synapse.handlers.groups_local import GroupsLocalHandler from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@ -84,7 +89,9 @@ class GroupServlet(RestServlet):
assert_params_in_dict( assert_params_in_dict(
content, ("name", "avatar_url", "short_description", "long_description") content, ("name", "avatar_url", "short_description", "long_description")
) )
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot create group profiles."
await self.groups_handler.update_group_profile( await self.groups_handler.update_group_profile(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -137,13 +144,26 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_PUT( async def on_PUT(
self, request: Request, group_id: str, category_id: str, room_id: str self, request: Request, group_id: str, category_id: Optional[str], room_id: str
): ):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
if category_id == "":
raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM)
if category_id and len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
raise SynapseError(
400,
"category_id may not be longer than %s characters"
% (MAX_GROUP_CATEGORYID_LENGTH,),
Codes.INVALID_PARAM,
)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot modify group summaries."
resp = await self.groups_handler.update_group_summary_room( resp = await self.groups_handler.update_group_summary_room(
group_id, group_id,
requester_user_id, requester_user_id,
@ -161,7 +181,9 @@ class GroupSummaryRoomsCatServlet(RestServlet):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot modify group profiles."
resp = await self.groups_handler.delete_group_summary_room( resp = await self.groups_handler.delete_group_summary_room(
group_id, requester_user_id, room_id=room_id, category_id=category_id group_id, requester_user_id, room_id=room_id, category_id=category_id
) )
@ -202,8 +224,21 @@ class GroupCategoryServlet(RestServlet):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
if not category_id:
raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM)
if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
raise SynapseError(
400,
"category_id may not be longer than %s characters"
% (MAX_GROUP_CATEGORYID_LENGTH,),
Codes.INVALID_PARAM,
)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot modify group categories."
resp = await self.groups_handler.update_group_category( resp = await self.groups_handler.update_group_category(
group_id, requester_user_id, category_id=category_id, content=content group_id, requester_user_id, category_id=category_id, content=content
) )
@ -217,7 +252,9 @@ class GroupCategoryServlet(RestServlet):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot modify group categories."
resp = await self.groups_handler.delete_group_category( resp = await self.groups_handler.delete_group_category(
group_id, requester_user_id, category_id=category_id group_id, requester_user_id, category_id=category_id
) )
@ -279,8 +316,21 @@ class GroupRoleServlet(RestServlet):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
if not role_id:
raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
raise SynapseError(
400,
"role_id may not be longer than %s characters"
% (MAX_GROUP_ROLEID_LENGTH,),
Codes.INVALID_PARAM,
)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot modify group roles."
resp = await self.groups_handler.update_group_role( resp = await self.groups_handler.update_group_role(
group_id, requester_user_id, role_id=role_id, content=content group_id, requester_user_id, role_id=role_id, content=content
) )
@ -294,7 +344,9 @@ class GroupRoleServlet(RestServlet):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot modify group roles."
resp = await self.groups_handler.delete_group_role( resp = await self.groups_handler.delete_group_role(
group_id, requester_user_id, role_id=role_id group_id, requester_user_id, role_id=role_id
) )
@ -347,13 +399,26 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_PUT( async def on_PUT(
self, request: Request, group_id: str, role_id: str, user_id: str self, request: Request, group_id: str, role_id: Optional[str], user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
if role_id == "":
raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
if role_id and len(role_id) > MAX_GROUP_ROLEID_LENGTH:
raise SynapseError(
400,
"role_id may not be longer than %s characters"
% (MAX_GROUP_ROLEID_LENGTH,),
Codes.INVALID_PARAM,
)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot modify group summaries."
resp = await self.groups_handler.update_group_summary_user( resp = await self.groups_handler.update_group_summary_user(
group_id, group_id,
requester_user_id, requester_user_id,
@ -371,7 +436,9 @@ class GroupSummaryUsersRoleServlet(RestServlet):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot modify group summaries."
resp = await self.groups_handler.delete_group_summary_user( resp = await self.groups_handler.delete_group_summary_user(
group_id, requester_user_id, user_id=user_id, role_id=role_id group_id, requester_user_id, user_id=user_id, role_id=role_id
) )
@ -465,7 +532,9 @@ class GroupSettingJoinPolicyServlet(RestServlet):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot modify group join policy."
result = await self.groups_handler.set_group_join_policy( result = await self.groups_handler.set_group_join_policy(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -494,7 +563,19 @@ class GroupCreateServlet(RestServlet):
localpart = content.pop("localpart") localpart = content.pop("localpart")
group_id = GroupID(localpart, self.server_name).to_string() group_id = GroupID(localpart, self.server_name).to_string()
assert isinstance(self.groups_handler, GroupsLocalHandler) if not localpart:
raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
if len(group_id) > MAX_GROUPID_LENGTH:
raise SynapseError(
400,
"Group ID may not be longer than %s characters" % (MAX_GROUPID_LENGTH,),
Codes.INVALID_PARAM,
)
assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot create groups."
result = await self.groups_handler.create_group( result = await self.groups_handler.create_group(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -523,7 +604,9 @@ class GroupAdminRoomsServlet(RestServlet):
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot modify rooms in a group."
result = await self.groups_handler.add_room_to_group( result = await self.groups_handler.add_room_to_group(
group_id, requester_user_id, room_id, content group_id, requester_user_id, room_id, content
) )
@ -537,7 +620,9 @@ class GroupAdminRoomsServlet(RestServlet):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot modify group categories."
result = await self.groups_handler.remove_room_from_group( result = await self.groups_handler.remove_room_from_group(
group_id, requester_user_id, room_id group_id, requester_user_id, room_id
) )
@ -567,7 +652,9 @@ class GroupAdminRoomsConfigServlet(RestServlet):
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot modify group categories."
result = await self.groups_handler.update_room_in_group( result = await self.groups_handler.update_room_in_group(
group_id, requester_user_id, room_id, config_key, content group_id, requester_user_id, room_id, config_key, content
) )
@ -597,7 +684,9 @@ class GroupAdminUsersInviteServlet(RestServlet):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
config = content.get("config", {}) config = content.get("config", {})
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot invite users to a group."
result = await self.groups_handler.invite( result = await self.groups_handler.invite(
group_id, user_id, requester_user_id, config group_id, user_id, requester_user_id, config
) )
@ -624,7 +713,9 @@ class GroupAdminUsersKickServlet(RestServlet):
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot kick users from a group."
result = await self.groups_handler.remove_user_from_group( result = await self.groups_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content group_id, user_id, requester_user_id, content
) )
@ -649,7 +740,9 @@ class GroupSelfLeaveServlet(RestServlet):
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot leave a group for a users."
result = await self.groups_handler.remove_user_from_group( result = await self.groups_handler.remove_user_from_group(
group_id, requester_user_id, requester_user_id, content group_id, requester_user_id, requester_user_id, content
) )
@ -674,7 +767,9 @@ class GroupSelfJoinServlet(RestServlet):
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot join a user to a group."
result = await self.groups_handler.join_group( result = await self.groups_handler.join_group(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -699,7 +794,9 @@ class GroupSelfAcceptInviteServlet(RestServlet):
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler) assert isinstance(
self.groups_handler, GroupsLocalHandler
), "Workers cannot accept an invite to a group."
result = await self.groups_handler.accept_invite( result = await self.groups_handler.accept_invite(
group_id, requester_user_id, content group_id, requester_user_id, content
) )

View file

@ -14,7 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple
from typing_extensions import TypedDict
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
@ -26,6 +28,9 @@ from synapse.util import json_encoder
_DEFAULT_CATEGORY_ID = "" _DEFAULT_CATEGORY_ID = ""
_DEFAULT_ROLE_ID = "" _DEFAULT_ROLE_ID = ""
# A room in a group.
_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool})
class GroupServerWorkerStore(SQLBaseStore): class GroupServerWorkerStore(SQLBaseStore):
async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]: async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
@ -72,7 +77,7 @@ class GroupServerWorkerStore(SQLBaseStore):
async def get_rooms_in_group( async def get_rooms_in_group(
self, group_id: str, include_private: bool = False self, group_id: str, include_private: bool = False
) -> List[Dict[str, Union[str, bool]]]: ) -> List[_RoomInGroup]:
"""Retrieve the rooms that belong to a given group. Does not return rooms that """Retrieve the rooms that belong to a given group. Does not return rooms that
lack members. lack members.