Make cached account data/tags/admin types immutable (#16325)

This commit is contained in:
Patrick Cloke 2023-09-18 09:55:04 -04:00 committed by GitHub
parent 85bfd4735e
commit c1e244c8f7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 55 additions and 50 deletions

1
changelog.d/16325.misc Normal file
View file

@ -0,0 +1 @@
Improve type hints.

View file

@ -17,7 +17,7 @@ import logging
import os import os
import sys import sys
import tempfile import tempfile
from typing import List, Mapping, Optional from typing import List, Mapping, Optional, Sequence
from twisted.internet import defer, task from twisted.internet import defer, task
@ -57,7 +57,7 @@ from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.storage.databases.main.tags import TagsWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore
from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
from synapse.types import JsonDict, StateMap from synapse.types import JsonMapping, StateMap
from synapse.util import SYNAPSE_VERSION from synapse.util import SYNAPSE_VERSION
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
@ -198,7 +198,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
for event in state.values(): for event in state.values():
json.dump(event, fp=f) json.dump(event, fp=f)
def write_profile(self, profile: JsonDict) -> None: def write_profile(self, profile: JsonMapping) -> None:
user_directory = os.path.join(self.base_directory, "user_data") user_directory = os.path.join(self.base_directory, "user_data")
os.makedirs(user_directory, exist_ok=True) os.makedirs(user_directory, exist_ok=True)
profile_file = os.path.join(user_directory, "profile") profile_file = os.path.join(user_directory, "profile")
@ -206,7 +206,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
with open(profile_file, "a") as f: with open(profile_file, "a") as f:
json.dump(profile, fp=f) json.dump(profile, fp=f)
def write_devices(self, devices: List[JsonDict]) -> None: def write_devices(self, devices: Sequence[JsonMapping]) -> None:
user_directory = os.path.join(self.base_directory, "user_data") user_directory = os.path.join(self.base_directory, "user_data")
os.makedirs(user_directory, exist_ok=True) os.makedirs(user_directory, exist_ok=True)
device_file = os.path.join(user_directory, "devices") device_file = os.path.join(user_directory, "devices")
@ -215,7 +215,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
with open(device_file, "a") as f: with open(device_file, "a") as f:
json.dump(device, fp=f) json.dump(device, fp=f)
def write_connections(self, connections: List[JsonDict]) -> None: def write_connections(self, connections: Sequence[JsonMapping]) -> None:
user_directory = os.path.join(self.base_directory, "user_data") user_directory = os.path.join(self.base_directory, "user_data")
os.makedirs(user_directory, exist_ok=True) os.makedirs(user_directory, exist_ok=True)
connection_file = os.path.join(user_directory, "connections") connection_file = os.path.join(user_directory, "connections")
@ -225,7 +225,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
json.dump(connection, fp=f) json.dump(connection, fp=f)
def write_account_data( def write_account_data(
self, file_name: str, account_data: Mapping[str, JsonDict] self, file_name: str, account_data: Mapping[str, JsonMapping]
) -> None: ) -> None:
account_data_directory = os.path.join( account_data_directory = os.path.join(
self.base_directory, "user_data", "account_data" self.base_directory, "user_data", "account_data"
@ -237,7 +237,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
with open(account_data_file, "a") as f: with open(account_data_file, "a") as f:
json.dump(account_data, fp=f) json.dump(account_data, fp=f)
def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None: def write_media_id(self, media_id: str, media_metadata: JsonMapping) -> None:
file_directory = os.path.join(self.base_directory, "media_ids") file_directory = os.path.join(self.base_directory, "media_ids")
os.makedirs(file_directory, exist_ok=True) os.makedirs(file_directory, exist_ok=True)
media_id_file = os.path.join(file_directory, media_id) media_id_file = os.path.join(file_directory, media_id)

View file

@ -14,11 +14,11 @@
import abc import abc
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set
from synapse.api.constants import Direction, Membership from synapse.api.constants import Direction, Membership
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
if TYPE_CHECKING: if TYPE_CHECKING:
@ -35,7 +35,7 @@ class AdminHandler:
self._state_storage_controller = self._storage_controllers.state self._state_storage_controller = self._storage_controllers.state
self._msc3866_enabled = hs.config.experimental.msc3866.enabled self._msc3866_enabled = hs.config.experimental.msc3866.enabled
async def get_whois(self, user: UserID) -> JsonDict: async def get_whois(self, user: UserID) -> JsonMapping:
connections = [] connections = []
sessions = await self._store.get_user_ip_and_agents(user) sessions = await self._store.get_user_ip_and_agents(user)
@ -55,7 +55,7 @@ class AdminHandler:
return ret return ret
async def get_user(self, user: UserID) -> Optional[JsonDict]: async def get_user(self, user: UserID) -> Optional[JsonMapping]:
"""Function to get user details""" """Function to get user details"""
user_info: Optional[UserInfo] = await self._store.get_user_by_id( user_info: Optional[UserInfo] = await self._store.get_user_by_id(
user.to_string() user.to_string()
@ -344,7 +344,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
def write_profile(self, profile: JsonDict) -> None: def write_profile(self, profile: JsonMapping) -> None:
"""Write the profile of a user. """Write the profile of a user.
Args: Args:
@ -353,7 +353,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
def write_devices(self, devices: List[JsonDict]) -> None: def write_devices(self, devices: Sequence[JsonMapping]) -> None:
"""Write the devices of a user. """Write the devices of a user.
Args: Args:
@ -362,7 +362,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
def write_connections(self, connections: List[JsonDict]) -> None: def write_connections(self, connections: Sequence[JsonMapping]) -> None:
"""Write the connections of a user. """Write the connections of a user.
Args: Args:
@ -372,7 +372,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def write_account_data( def write_account_data(
self, file_name: str, account_data: Mapping[str, JsonDict] self, file_name: str, account_data: Mapping[str, JsonMapping]
) -> None: ) -> None:
"""Write the account data of a user. """Write the account data of a user.
@ -383,7 +383,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None: def write_media_id(self, media_id: str, media_metadata: JsonMapping) -> None:
"""Write the media's metadata of a user. """Write the media's metadata of a user.
Exports only the metadata, as this can be fetched from the database via Exports only the metadata, as this can be fetched from the database via
read only. In order to access the files, a connection to the correct read only. In order to access the files, a connection to the correct

View file

@ -57,6 +57,7 @@ from synapse.storage.roommember import MemberSummary
from synapse.types import ( from synapse.types import (
DeviceListUpdates, DeviceListUpdates,
JsonDict, JsonDict,
JsonMapping,
MutableStateMap, MutableStateMap,
Requester, Requester,
RoomStreamToken, RoomStreamToken,
@ -1793,19 +1794,23 @@ class SyncHandler:
) )
if push_rules_changed: if push_rules_changed:
global_account_data = dict(global_account_data) global_account_data = {
global_account_data[ AccountDataTypes.PUSH_RULES: await self._push_rules_handler.push_rules_for_user(
AccountDataTypes.PUSH_RULES sync_config.user
] = await self._push_rules_handler.push_rules_for_user(sync_config.user) ),
**global_account_data,
}
else: else:
all_global_account_data = await self.store.get_global_account_data_for_user( all_global_account_data = await self.store.get_global_account_data_for_user(
user_id user_id
) )
global_account_data = dict(all_global_account_data) global_account_data = {
global_account_data[ AccountDataTypes.PUSH_RULES: await self._push_rules_handler.push_rules_for_user(
AccountDataTypes.PUSH_RULES sync_config.user
] = await self._push_rules_handler.push_rules_for_user(sync_config.user) ),
**all_global_account_data,
}
account_data_for_user = ( account_data_for_user = (
await sync_config.filter_collection.filter_global_account_data( await sync_config.filter_collection.filter_global_account_data(
@ -1909,7 +1914,7 @@ class SyncHandler:
blocks_all_rooms blocks_all_rooms
or sync_result_builder.sync_config.filter_collection.blocks_all_room_account_data() or sync_result_builder.sync_config.filter_collection.blocks_all_room_account_data()
): ):
account_data_by_room: Mapping[str, Mapping[str, JsonDict]] = {} account_data_by_room: Mapping[str, Mapping[str, JsonMapping]] = {}
elif since_token and not sync_result_builder.full_state: elif since_token and not sync_result_builder.full_state:
account_data_by_room = ( account_data_by_room = (
await self.store.get_updated_room_account_data_for_user( await self.store.get_updated_room_account_data_for_user(
@ -2349,8 +2354,8 @@ class SyncHandler:
sync_result_builder: "SyncResultBuilder", sync_result_builder: "SyncResultBuilder",
room_builder: "RoomSyncResultBuilder", room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict], ephemeral: List[JsonDict],
tags: Optional[Mapping[str, Mapping[str, Any]]], tags: Optional[Mapping[str, JsonMapping]],
account_data: Mapping[str, JsonDict], account_data: Mapping[str, JsonMapping],
always_include: bool = False, always_include: bool = False,
) -> None: ) -> None:
"""Populates the `joined` and `archived` section of `sync_result_builder` """Populates the `joined` and `archived` section of `sync_result_builder`

View file

@ -39,7 +39,7 @@ from synapse.rest.admin._base import (
from synapse.rest.client._base import client_patterns from synapse.rest.client._base import client_patterns
from synapse.storage.databases.main.registration import ExternalIDReuseException from synapse.storage.databases.main.registration import ExternalIDReuseException
from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.databases.main.stats import UserSortOrder
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, JsonMapping, UserID
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -211,7 +211,7 @@ class UserRestServletV2(RestServlet):
async def on_GET( async def on_GET(
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonMapping]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
@ -226,7 +226,7 @@ class UserRestServletV2(RestServlet):
async def on_PUT( async def on_PUT(
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonMapping]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester) await assert_user_is_admin(self.auth, requester)
@ -658,7 +658,7 @@ class WhoisRestServlet(RestServlet):
async def on_GET( async def on_GET(
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonMapping]:
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)

View file

@ -20,7 +20,7 @@ from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, RoomID from synapse.types import JsonDict, JsonMapping, RoomID
from ._base import client_patterns from ._base import client_patterns
@ -95,7 +95,7 @@ class AccountDataServlet(RestServlet):
async def on_GET( async def on_GET(
self, request: SynapseRequest, user_id: str, account_data_type: str self, request: SynapseRequest, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonMapping]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.") raise AuthError(403, "Cannot get account data for other users.")
@ -106,7 +106,7 @@ class AccountDataServlet(RestServlet):
and account_data_type == AccountDataTypes.PUSH_RULES and account_data_type == AccountDataTypes.PUSH_RULES
): ):
account_data: Optional[ account_data: Optional[
JsonDict JsonMapping
] = await self._push_rules_handler.push_rules_for_user(requester.user) ] = await self._push_rules_handler.push_rules_for_user(requester.user)
else: else:
account_data = await self.store.get_global_account_data_by_type_for_user( account_data = await self.store.get_global_account_data_by_type_for_user(
@ -236,7 +236,7 @@ class RoomAccountDataServlet(RestServlet):
user_id: str, user_id: str,
room_id: str, room_id: str,
account_data_type: str, account_data_type: str,
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonMapping]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.") raise AuthError(403, "Cannot get account data for other users.")
@ -253,7 +253,7 @@ class RoomAccountDataServlet(RestServlet):
self._hs.config.experimental.msc4010_push_rules_account_data self._hs.config.experimental.msc4010_push_rules_account_data
and account_data_type == AccountDataTypes.PUSH_RULES and account_data_type == AccountDataTypes.PUSH_RULES
): ):
account_data: Optional[JsonDict] = {} account_data: Optional[JsonMapping] = {}
else: else:
account_data = await self.store.get_account_data_for_room_and_type( account_data = await self.store.get_account_data_for_room_and_type(
user_id, room_id, account_data_type user_id, room_id, account_data_type

View file

@ -43,7 +43,7 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator, MultiWriterIdGenerator,
StreamIdGenerator, StreamIdGenerator,
) )
from synapse.types import JsonDict from synapse.types import JsonDict, JsonMapping
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -119,7 +119,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
@cached() @cached()
async def get_global_account_data_for_user( async def get_global_account_data_for_user(
self, user_id: str self, user_id: str
) -> Mapping[str, JsonDict]: ) -> Mapping[str, JsonMapping]:
""" """
Get all the global client account_data for a user. Get all the global client account_data for a user.
@ -164,7 +164,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
@cached() @cached()
async def get_room_account_data_for_user( async def get_room_account_data_for_user(
self, user_id: str self, user_id: str
) -> Mapping[str, Mapping[str, JsonDict]]: ) -> Mapping[str, Mapping[str, JsonMapping]]:
""" """
Get all of the per-room client account_data for a user. Get all of the per-room client account_data for a user.
@ -213,7 +213,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
@cached(num_args=2, max_entries=5000, tree=True) @cached(num_args=2, max_entries=5000, tree=True)
async def get_global_account_data_by_type_for_user( async def get_global_account_data_by_type_for_user(
self, user_id: str, data_type: str self, user_id: str, data_type: str
) -> Optional[JsonDict]: ) -> Optional[JsonMapping]:
""" """
Returns: Returns:
The account data. The account data.
@ -265,7 +265,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
@cached(num_args=2, tree=True) @cached(num_args=2, tree=True)
async def get_account_data_for_room( async def get_account_data_for_room(
self, user_id: str, room_id: str self, user_id: str, room_id: str
) -> Mapping[str, JsonDict]: ) -> Mapping[str, JsonMapping]:
"""Get all the client account_data for a user for a room. """Get all the client account_data for a user for a room.
Args: Args:
@ -296,7 +296,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
@cached(num_args=3, max_entries=5000, tree=True) @cached(num_args=3, max_entries=5000, tree=True)
async def get_account_data_for_room_and_type( async def get_account_data_for_room_and_type(
self, user_id: str, room_id: str, account_data_type: str self, user_id: str, room_id: str, account_data_type: str
) -> Optional[JsonDict]: ) -> Optional[JsonMapping]:
"""Get the client account_data of given type for a user for a room. """Get the client account_data of given type for a user for a room.
Args: Args:
@ -394,7 +394,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
async def get_updated_global_account_data_for_user( async def get_updated_global_account_data_for_user(
self, user_id: str, stream_id: int self, user_id: str, stream_id: int
) -> Dict[str, JsonDict]: ) -> Mapping[str, JsonMapping]:
"""Get all the global account_data that's changed for a user. """Get all the global account_data that's changed for a user.
Args: Args:

View file

@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict, FrozenSet
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.types import StrCollection
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
if TYPE_CHECKING: if TYPE_CHECKING:
@ -34,7 +33,7 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
@cached() @cached()
async def list_enabled_features(self, user_id: str) -> StrCollection: async def list_enabled_features(self, user_id: str) -> FrozenSet[str]:
""" """
Checks to see what features are enabled for a given user Checks to see what features are enabled for a given user
Args: Args:
@ -49,7 +48,7 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
["feature"], ["feature"],
) )
return [feature["feature"] for feature in enabled] return frozenset(feature["feature"] for feature in enabled)
async def set_features_for_user( async def set_features_for_user(
self, self,

View file

@ -23,7 +23,7 @@ from synapse.storage._base import db_to_json
from synapse.storage.database import LoggingTransaction from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.types import JsonDict from synapse.types import JsonDict, JsonMapping
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -34,7 +34,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
@cached() @cached()
async def get_tags_for_user( async def get_tags_for_user(
self, user_id: str self, user_id: str
) -> Mapping[str, Mapping[str, JsonDict]]: ) -> Mapping[str, Mapping[str, JsonMapping]]:
"""Get all the tags for a user. """Get all the tags for a user.
@ -109,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_updated_tags( async def get_updated_tags(
self, user_id: str, stream_id: int self, user_id: str, stream_id: int
) -> Mapping[str, Mapping[str, JsonDict]]: ) -> Mapping[str, Mapping[str, JsonMapping]]:
"""Get all the tags for the rooms where the tags have changed since the """Get all the tags for the rooms where the tags have changed since the
given version given version