diff --git a/changelog.d/16611.misc b/changelog.d/16611.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/16611.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index c2109036ec..1027fbfd28 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import random -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union from synapse.api.errors import ( AuthError, @@ -23,6 +23,7 @@ from synapse.api.errors import ( StoreError, SynapseError, ) +from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util.caches.descriptors import cached from synapse.util.stringutils import parse_and_validate_mxc_uri @@ -306,7 +307,9 @@ class ProfileHandler: server_name = host if self._is_mine_server_name(server_name): - media_info = await self.store.get_local_media(media_id) + media_info: Optional[ + Union[LocalMedia, RemoteMedia] + ] = await self.store.get_local_media(media_id) else: media_info = await self.store.get_cached_remote_media(server_name, media_id) @@ -322,12 +325,12 @@ class ProfileHandler: if self.max_avatar_size: # Ensure avatar does not exceed max allowed avatar size - if media_info["media_length"] > self.max_avatar_size: + if media_info.media_length > self.max_avatar_size: logger.warning( "Forbidding avatar change to %s: %d bytes is above the allowed size " "limit", mxc, - media_info["media_length"], + media_info.media_length, ) return False @@ -335,12 +338,12 @@ class ProfileHandler: # Ensure the avatar's file type is allowed if ( self.allowed_avatar_mimetypes - and media_info["media_type"] not in self.allowed_avatar_mimetypes + and media_info.media_type not in self.allowed_avatar_mimetypes ): logger.warning( "Forbidding avatar change to %s: mimetype %s not allowed", mxc, - media_info["media_type"], + media_info.media_type, ) return False diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 62f2454f5d..389dc5298a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -806,7 +806,7 @@ class SsoHandler: media_id = profile["avatar_url"].split("/")[-1] if self._is_mine_server_name(server_name): media = await self._media_repo.store.get_local_media(media_id) - if media is not None and upload_name == media["upload_name"]: + if media is not None and upload_name == media.upload_name: logger.info("skipping saving the user avatar") return True diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 72b0f1c5de..1957426c6a 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -19,6 +19,7 @@ import shutil from io import BytesIO from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple +import attr from matrix_common.types.mxc_uri import MXCUri import twisted.internet.error @@ -50,6 +51,7 @@ from synapse.media.storage_provider import StorageProviderWrapper from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.url_previewer import UrlPreviewer from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases.main.media_repository import RemoteMedia from synapse.types import UserID from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination @@ -245,18 +247,18 @@ class MediaRepository: Resolves once a response has successfully been written to request """ media_info = await self.store.get_local_media(media_id) - if not media_info or media_info["quarantined_by"]: + if not media_info or media_info.quarantined_by: respond_404(request) return self.mark_recently_accessed(None, media_id) - media_type = media_info["media_type"] + media_type = media_info.media_type if not media_type: media_type = "application/octet-stream" - media_length = media_info["media_length"] - upload_name = name if name else media_info["upload_name"] - url_cache = media_info["url_cache"] + media_length = media_info.media_length + upload_name = name if name else media_info.upload_name + url_cache = media_info.url_cache file_info = FileInfo(None, media_id, url_cache=bool(url_cache)) @@ -310,16 +312,20 @@ class MediaRepository: # We deliberately stream the file outside the lock if responder: - media_type = media_info["media_type"] - media_length = media_info["media_length"] - upload_name = name if name else media_info["upload_name"] + upload_name = name if name else media_info.upload_name await respond_with_responder( - request, responder, media_type, media_length, upload_name + request, + responder, + media_info.media_type, + media_info.media_length, + upload_name, ) else: respond_404(request) - async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: + async def get_remote_media_info( + self, server_name: str, media_id: str + ) -> RemoteMedia: """Gets the media info associated with the remote file, downloading if necessary. @@ -353,7 +359,7 @@ class MediaRepository: async def _get_remote_media_impl( self, server_name: str, media_id: str - ) -> Tuple[Optional[Responder], dict]: + ) -> Tuple[Optional[Responder], RemoteMedia]: """Looks for media in local cache, if not there then attempt to download from remote server. @@ -373,15 +379,17 @@ class MediaRepository: # If we have an entry in the DB, try and look for it if media_info: - file_id = media_info["filesystem_id"] + file_id = media_info.filesystem_id file_info = FileInfo(server_name, file_id) - if media_info["quarantined_by"]: + if media_info.quarantined_by: logger.info("Media is quarantined") raise NotFoundError() - if not media_info["media_type"]: - media_info["media_type"] = "application/octet-stream" + if not media_info.media_type: + media_info = attr.evolve( + media_info, media_type="application/octet-stream" + ) responder = await self.media_storage.fetch_media(file_info) if responder: @@ -403,9 +411,9 @@ class MediaRepository: if not media_info: raise e - file_id = media_info["filesystem_id"] - if not media_info["media_type"]: - media_info["media_type"] = "application/octet-stream" + file_id = media_info.filesystem_id + if not media_info.media_type: + media_info = attr.evolve(media_info, media_type="application/octet-stream") file_info = FileInfo(server_name, file_id) # We generate thumbnails even if another process downloaded the media @@ -415,7 +423,7 @@ class MediaRepository: # otherwise they'll request thumbnails and get a 404 if they're not # ready yet. await self._generate_thumbnails( - server_name, media_id, file_id, media_info["media_type"] + server_name, media_id, file_id, media_info.media_type ) responder = await self.media_storage.fetch_media(file_info) @@ -425,7 +433,7 @@ class MediaRepository: self, server_name: str, media_id: str, - ) -> dict: + ) -> RemoteMedia: """Attempt to download the remote file from the given server name, using the given file_id as the local id. @@ -518,7 +526,7 @@ class MediaRepository: origin=server_name, media_id=media_id, media_type=media_type, - time_now_ms=self.clock.time_msec(), + time_now_ms=time_now_ms, upload_name=upload_name, media_length=length, filesystem_id=file_id, @@ -526,15 +534,17 @@ class MediaRepository: logger.info("Stored remote media in file %r", fname) - media_info = { - "media_type": media_type, - "media_length": length, - "upload_name": upload_name, - "created_ts": time_now_ms, - "filesystem_id": file_id, - } - - return media_info + return RemoteMedia( + media_origin=server_name, + media_id=media_id, + media_type=media_type, + media_length=length, + upload_name=upload_name, + created_ts=time_now_ms, + filesystem_id=file_id, + last_access_ts=time_now_ms, + quarantined_by=None, + ) def _get_thumbnail_requirements( self, media_type: str diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py index 9b5a3dd5f4..44aac21de6 100644 --- a/synapse/media/url_previewer.py +++ b/synapse/media/url_previewer.py @@ -240,15 +240,14 @@ class UrlPreviewer: cache_result = await self.store.get_url_cache(url, ts) if ( cache_result - and cache_result["expires_ts"] > ts - and cache_result["response_code"] / 100 == 2 + and cache_result.expires_ts > ts + and cache_result.response_code // 100 == 2 ): # It may be stored as text in the database, not as bytes (such as # PostgreSQL). If so, encode it back before handing it on. - og = cache_result["og"] - if isinstance(og, str): - og = og.encode("utf8") - return og + if isinstance(cache_result.og, str): + return cache_result.og.encode("utf8") + return cache_result.og # If this URL can be accessed via an allowed oEmbed, use that instead. url_to_download = url diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py index 85b6bdbe72..efda8b4ab4 100644 --- a/synapse/rest/media/thumbnail_resource.py +++ b/synapse/rest/media/thumbnail_resource.py @@ -119,7 +119,7 @@ class ThumbnailResource(RestServlet): if not media_info: respond_404(request) return - if media_info["quarantined_by"]: + if media_info.quarantined_by: logger.info("Media is quarantined") respond_404(request) return @@ -134,7 +134,7 @@ class ThumbnailResource(RestServlet): thumbnail_infos, media_id, media_id, - url_cache=bool(media_info["url_cache"]), + url_cache=bool(media_info.url_cache), server_name=None, ) @@ -152,7 +152,7 @@ class ThumbnailResource(RestServlet): if not media_info: respond_404(request) return - if media_info["quarantined_by"]: + if media_info.quarantined_by: logger.info("Media is quarantined") respond_404(request) return @@ -168,7 +168,7 @@ class ThumbnailResource(RestServlet): file_info = FileInfo( server_name=None, file_id=media_id, - url_cache=media_info["url_cache"], + url_cache=bool(media_info.url_cache), thumbnail=info, ) @@ -188,7 +188,7 @@ class ThumbnailResource(RestServlet): desired_height, desired_method, desired_type, - url_cache=bool(media_info["url_cache"]), + url_cache=bool(media_info.url_cache), ) if file_path: @@ -213,7 +213,7 @@ class ThumbnailResource(RestServlet): server_name, media_id ) - file_id = media_info["filesystem_id"] + file_id = media_info.filesystem_id for info in thumbnail_infos: t_w = info.width == desired_width @@ -224,7 +224,7 @@ class ThumbnailResource(RestServlet): if t_w and t_h and t_method and t_type: file_info = FileInfo( server_name=server_name, - file_id=media_info["filesystem_id"], + file_id=file_id, thumbnail=info, ) @@ -280,7 +280,7 @@ class ThumbnailResource(RestServlet): m_type, thumbnail_infos, media_id, - media_info["filesystem_id"], + media_info.filesystem_id, url_cache=False, server_name=server_name, ) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index c8d7c9fd32..7f99c64f1b 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -15,9 +15,7 @@ from enum import Enum from typing import ( TYPE_CHECKING, - Any, Collection, - Dict, Iterable, List, Optional, @@ -54,11 +52,32 @@ class LocalMedia: media_length: int upload_name: str created_ts: int + url_cache: Optional[str] last_access_ts: int quarantined_by: Optional[str] safe_from_quarantine: bool +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RemoteMedia: + media_origin: str + media_id: str + media_type: str + media_length: int + upload_name: Optional[str] + filesystem_id: str + created_ts: int + last_access_ts: int + quarantined_by: Optional[str] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class UrlCache: + response_code: int + expires_ts: int + og: Union[str, bytes] + + class MediaSortOrder(Enum): """ Enum to define the sorting method used when returning media with @@ -165,13 +184,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): super().__init__(database, db_conn, hs) self.server_name: str = hs.hostname - async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: + async def get_local_media(self, media_id: str) -> Optional[LocalMedia]: """Get the metadata for a local piece of media Returns: None if the media_id doesn't exist. """ - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -181,11 +200,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "created_ts", "quarantined_by", "url_cache", + "last_access_ts", "safe_from_quarantine", ), allow_none=True, desc="get_local_media", ) + if row is None: + return None + return LocalMedia(media_id=media_id, **row) async def get_local_media_by_user_paginate( self, @@ -236,6 +259,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): media_length, upload_name, created_ts, + url_cache, last_access_ts, quarantined_by, safe_from_quarantine @@ -257,9 +281,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): media_length=row[2], upload_name=row[3], created_ts=row[4], - last_access_ts=row[5], - quarantined_by=row[6], - safe_from_quarantine=bool(row[7]), + url_cache=row[5], + last_access_ts=row[6], + quarantined_by=row[7], + safe_from_quarantine=bool(row[8]), ) for row in txn ] @@ -390,51 +415,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="mark_local_media_as_safe", ) - async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]: + async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]: """Get the media_id and ts for a cached URL as of the given timestamp Returns: None if the URL isn't cached. """ - def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: + def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]: # get the most recently cached result (relative to the given ts) - sql = ( - "SELECT response_code, etag, expires_ts, og, media_id, download_ts" - " FROM local_media_repository_url_cache" - " WHERE url = ? AND download_ts <= ?" - " ORDER BY download_ts DESC LIMIT 1" - ) + sql = """ + SELECT response_code, expires_ts, og + FROM local_media_repository_url_cache + WHERE url = ? AND download_ts <= ? + ORDER BY download_ts DESC LIMIT 1 + """ txn.execute(sql, (url, ts)) row = txn.fetchone() if not row: # ...or if we've requested a timestamp older than the oldest # copy in the cache, return the oldest copy (if any) - sql = ( - "SELECT response_code, etag, expires_ts, og, media_id, download_ts" - " FROM local_media_repository_url_cache" - " WHERE url = ? AND download_ts > ?" - " ORDER BY download_ts ASC LIMIT 1" - ) + sql = """ + SELECT response_code, expires_ts, og + FROM local_media_repository_url_cache + WHERE url = ? AND download_ts > ? + ORDER BY download_ts ASC LIMIT 1 + """ txn.execute(sql, (url, ts)) row = txn.fetchone() if not row: return None - return dict( - zip( - ( - "response_code", - "etag", - "expires_ts", - "og", - "media_id", - "download_ts", - ), - row, - ) - ) + return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2]) return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) @@ -444,7 +457,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): response_code: int, etag: Optional[str], expires_ts: int, - og: Optional[str], + og: str, media_id: str, download_ts: int, ) -> None: @@ -510,8 +523,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def get_cached_remote_media( self, origin: str, media_id: str - ) -> Optional[Dict[str, Any]]: - return await self.db_pool.simple_select_one( + ) -> Optional[RemoteMedia]: + row = await self.db_pool.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( @@ -520,11 +533,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "upload_name", "created_ts", "filesystem_id", + "last_access_ts", "quarantined_by", ), allow_none=True, desc="get_cached_remote_media", ) + if row is None: + return row + return RemoteMedia(media_origin=origin, media_id=media_id, **row) async def store_cached_remote_media( self, @@ -623,10 +640,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): t_width: int, t_height: int, t_type: str, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[ThumbnailInfo]: """Fetch the thumbnail info of given width, height and type.""" - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( table="remote_media_cache_thumbnails", keyvalues={ "media_origin": origin, @@ -641,11 +658,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "thumbnail_method", "thumbnail_type", "thumbnail_length", - "filesystem_id", ), allow_none=True, desc="get_remote_media_thumbnail", ) + if row is None: + return None + return ThumbnailInfo( + width=row["thumbnail_width"], + height=row["thumbnail_height"], + method=row["thumbnail_method"], + type=row["thumbnail_type"], + length=row["thumbnail_length"], + ) @trace async def store_remote_media_thumbnail( diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index 15f5d644e4..a8e7a76b29 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -504,7 +504,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): origin, media_id = self.media_id.split("/") info = self.get_success(self.store.get_cached_remote_media(origin, media_id)) assert info is not None - file_id = info["filesystem_id"] + file_id = info.filesystem_id thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir( origin, file_id diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 278808abb5..dac79bd745 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -642,7 +642,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertFalse(media_info["quarantined_by"]) + self.assertFalse(media_info.quarantined_by) # quarantining channel = self.make_request( @@ -656,7 +656,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertTrue(media_info["quarantined_by"]) + self.assertTrue(media_info.quarantined_by) # remove from quarantine channel = self.make_request( @@ -670,7 +670,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertFalse(media_info["quarantined_by"]) + self.assertFalse(media_info.quarantined_by) def test_quarantine_protected_media(self) -> None: """ @@ -683,7 +683,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests): # verify protection media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertTrue(media_info["safe_from_quarantine"]) + self.assertTrue(media_info.safe_from_quarantine) # quarantining channel = self.make_request( @@ -698,7 +698,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests): # verify that is not in quarantine media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertFalse(media_info["quarantined_by"]) + self.assertFalse(media_info.quarantined_by) class ProtectMediaByIDTestCase(_AdminMediaTests): @@ -756,7 +756,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertFalse(media_info["safe_from_quarantine"]) + self.assertFalse(media_info.safe_from_quarantine) # protect channel = self.make_request( @@ -770,7 +770,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertTrue(media_info["safe_from_quarantine"]) + self.assertTrue(media_info.safe_from_quarantine) # unprotect channel = self.make_request( @@ -784,7 +784,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertFalse(media_info["safe_from_quarantine"]) + self.assertFalse(media_info.safe_from_quarantine) class PurgeMediaCacheTestCase(_AdminMediaTests): diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py index b59d9dfd4d..27a663a23b 100644 --- a/tests/rest/media/test_media_retention.py +++ b/tests/rest/media/test_media_retention.py @@ -267,23 +267,23 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None: """Given an MXC URI, assert whether it has been purged or not.""" if mxc_uri.server_name == self.hs.config.server.server_name: - found_media_dict = self.get_success( - self.store.get_local_media(mxc_uri.media_id) + found_media = bool( + self.get_success(self.store.get_local_media(mxc_uri.media_id)) ) else: - found_media_dict = self.get_success( - self.store.get_cached_remote_media( - mxc_uri.server_name, mxc_uri.media_id + found_media = bool( + self.get_success( + self.store.get_cached_remote_media( + mxc_uri.server_name, mxc_uri.media_id + ) ) ) if expect_purged: - self.assertIsNone( - found_media_dict, msg=f"{mxc_uri} unexpectedly not purged" - ) + self.assertFalse(found_media, msg=f"{mxc_uri} unexpectedly not purged") else: - self.assertIsNotNone( - found_media_dict, + self.assertTrue( + found_media, msg=f"{mxc_uri} unexpectedly purged", )