Return attrs for more media repo APIs. (#16611)

This commit is contained in:
Patrick Cloke 2023-11-09 11:00:30 -05:00 committed by GitHub
parent 91587d4cf9
commit ff716b483b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 148 additions and 110 deletions

1
changelog.d/16611.misc Normal file
View file

@ -0,0 +1 @@
Improve type hints.

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import random import random
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional, Union
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
@ -23,6 +23,7 @@ from synapse.api.errors import (
StoreError, StoreError,
SynapseError, SynapseError,
) )
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import parse_and_validate_mxc_uri from synapse.util.stringutils import parse_and_validate_mxc_uri
@ -306,7 +307,9 @@ class ProfileHandler:
server_name = host server_name = host
if self._is_mine_server_name(server_name): if self._is_mine_server_name(server_name):
media_info = await self.store.get_local_media(media_id) media_info: Optional[
Union[LocalMedia, RemoteMedia]
] = await self.store.get_local_media(media_id)
else: else:
media_info = await self.store.get_cached_remote_media(server_name, media_id) media_info = await self.store.get_cached_remote_media(server_name, media_id)
@ -322,12 +325,12 @@ class ProfileHandler:
if self.max_avatar_size: if self.max_avatar_size:
# Ensure avatar does not exceed max allowed avatar size # Ensure avatar does not exceed max allowed avatar size
if media_info["media_length"] > self.max_avatar_size: if media_info.media_length > self.max_avatar_size:
logger.warning( logger.warning(
"Forbidding avatar change to %s: %d bytes is above the allowed size " "Forbidding avatar change to %s: %d bytes is above the allowed size "
"limit", "limit",
mxc, mxc,
media_info["media_length"], media_info.media_length,
) )
return False return False
@ -335,12 +338,12 @@ class ProfileHandler:
# Ensure the avatar's file type is allowed # Ensure the avatar's file type is allowed
if ( if (
self.allowed_avatar_mimetypes self.allowed_avatar_mimetypes
and media_info["media_type"] not in self.allowed_avatar_mimetypes and media_info.media_type not in self.allowed_avatar_mimetypes
): ):
logger.warning( logger.warning(
"Forbidding avatar change to %s: mimetype %s not allowed", "Forbidding avatar change to %s: mimetype %s not allowed",
mxc, mxc,
media_info["media_type"], media_info.media_type,
) )
return False return False

View file

@ -806,7 +806,7 @@ class SsoHandler:
media_id = profile["avatar_url"].split("/")[-1] media_id = profile["avatar_url"].split("/")[-1]
if self._is_mine_server_name(server_name): if self._is_mine_server_name(server_name):
media = await self._media_repo.store.get_local_media(media_id) media = await self._media_repo.store.get_local_media(media_id)
if media is not None and upload_name == media["upload_name"]: if media is not None and upload_name == media.upload_name:
logger.info("skipping saving the user avatar") logger.info("skipping saving the user avatar")
return True return True

View file

@ -19,6 +19,7 @@ import shutil
from io import BytesIO from io import BytesIO
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import attr
from matrix_common.types.mxc_uri import MXCUri from matrix_common.types.mxc_uri import MXCUri
import twisted.internet.error import twisted.internet.error
@ -50,6 +51,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
from synapse.media.url_previewer import UrlPreviewer from synapse.media.url_previewer import UrlPreviewer
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main.media_repository import RemoteMedia
from synapse.types import UserID from synapse.types import UserID
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
@ -245,18 +247,18 @@ class MediaRepository:
Resolves once a response has successfully been written to request Resolves once a response has successfully been written to request
""" """
media_info = await self.store.get_local_media(media_id) media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]: if not media_info or media_info.quarantined_by:
respond_404(request) respond_404(request)
return return
self.mark_recently_accessed(None, media_id) self.mark_recently_accessed(None, media_id)
media_type = media_info["media_type"] media_type = media_info.media_type
if not media_type: if not media_type:
media_type = "application/octet-stream" media_type = "application/octet-stream"
media_length = media_info["media_length"] media_length = media_info.media_length
upload_name = name if name else media_info["upload_name"] upload_name = name if name else media_info.upload_name
url_cache = media_info["url_cache"] url_cache = media_info.url_cache
file_info = FileInfo(None, media_id, url_cache=bool(url_cache)) file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
@ -310,16 +312,20 @@ class MediaRepository:
# We deliberately stream the file outside the lock # We deliberately stream the file outside the lock
if responder: if responder:
media_type = media_info["media_type"] upload_name = name if name else media_info.upload_name
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
await respond_with_responder( await respond_with_responder(
request, responder, media_type, media_length, upload_name request,
responder,
media_info.media_type,
media_info.media_length,
upload_name,
) )
else: else:
respond_404(request) respond_404(request)
async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: async def get_remote_media_info(
self, server_name: str, media_id: str
) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading """Gets the media info associated with the remote file, downloading
if necessary. if necessary.
@ -353,7 +359,7 @@ class MediaRepository:
async def _get_remote_media_impl( async def _get_remote_media_impl(
self, server_name: str, media_id: str self, server_name: str, media_id: str
) -> Tuple[Optional[Responder], dict]: ) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to """Looks for media in local cache, if not there then attempt to
download from remote server. download from remote server.
@ -373,15 +379,17 @@ class MediaRepository:
# If we have an entry in the DB, try and look for it # If we have an entry in the DB, try and look for it
if media_info: if media_info:
file_id = media_info["filesystem_id"] file_id = media_info.filesystem_id
file_info = FileInfo(server_name, file_id) file_info = FileInfo(server_name, file_id)
if media_info["quarantined_by"]: if media_info.quarantined_by:
logger.info("Media is quarantined") logger.info("Media is quarantined")
raise NotFoundError() raise NotFoundError()
if not media_info["media_type"]: if not media_info.media_type:
media_info["media_type"] = "application/octet-stream" media_info = attr.evolve(
media_info, media_type="application/octet-stream"
)
responder = await self.media_storage.fetch_media(file_info) responder = await self.media_storage.fetch_media(file_info)
if responder: if responder:
@ -403,9 +411,9 @@ class MediaRepository:
if not media_info: if not media_info:
raise e raise e
file_id = media_info["filesystem_id"] file_id = media_info.filesystem_id
if not media_info["media_type"]: if not media_info.media_type:
media_info["media_type"] = "application/octet-stream" media_info = attr.evolve(media_info, media_type="application/octet-stream")
file_info = FileInfo(server_name, file_id) file_info = FileInfo(server_name, file_id)
# We generate thumbnails even if another process downloaded the media # We generate thumbnails even if another process downloaded the media
@ -415,7 +423,7 @@ class MediaRepository:
# otherwise they'll request thumbnails and get a 404 if they're not # otherwise they'll request thumbnails and get a 404 if they're not
# ready yet. # ready yet.
await self._generate_thumbnails( await self._generate_thumbnails(
server_name, media_id, file_id, media_info["media_type"] server_name, media_id, file_id, media_info.media_type
) )
responder = await self.media_storage.fetch_media(file_info) responder = await self.media_storage.fetch_media(file_info)
@ -425,7 +433,7 @@ class MediaRepository:
self, self,
server_name: str, server_name: str,
media_id: str, media_id: str,
) -> dict: ) -> RemoteMedia:
"""Attempt to download the remote file from the given server name, """Attempt to download the remote file from the given server name,
using the given file_id as the local id. using the given file_id as the local id.
@ -518,7 +526,7 @@ class MediaRepository:
origin=server_name, origin=server_name,
media_id=media_id, media_id=media_id,
media_type=media_type, media_type=media_type,
time_now_ms=self.clock.time_msec(), time_now_ms=time_now_ms,
upload_name=upload_name, upload_name=upload_name,
media_length=length, media_length=length,
filesystem_id=file_id, filesystem_id=file_id,
@ -526,15 +534,17 @@ class MediaRepository:
logger.info("Stored remote media in file %r", fname) logger.info("Stored remote media in file %r", fname)
media_info = { return RemoteMedia(
"media_type": media_type, media_origin=server_name,
"media_length": length, media_id=media_id,
"upload_name": upload_name, media_type=media_type,
"created_ts": time_now_ms, media_length=length,
"filesystem_id": file_id, upload_name=upload_name,
} created_ts=time_now_ms,
filesystem_id=file_id,
return media_info last_access_ts=time_now_ms,
quarantined_by=None,
)
def _get_thumbnail_requirements( def _get_thumbnail_requirements(
self, media_type: str self, media_type: str

View file

@ -240,15 +240,14 @@ class UrlPreviewer:
cache_result = await self.store.get_url_cache(url, ts) cache_result = await self.store.get_url_cache(url, ts)
if ( if (
cache_result cache_result
and cache_result["expires_ts"] > ts and cache_result.expires_ts > ts
and cache_result["response_code"] / 100 == 2 and cache_result.response_code // 100 == 2
): ):
# It may be stored as text in the database, not as bytes (such as # It may be stored as text in the database, not as bytes (such as
# PostgreSQL). If so, encode it back before handing it on. # PostgreSQL). If so, encode it back before handing it on.
og = cache_result["og"] if isinstance(cache_result.og, str):
if isinstance(og, str): return cache_result.og.encode("utf8")
og = og.encode("utf8") return cache_result.og
return og
# If this URL can be accessed via an allowed oEmbed, use that instead. # If this URL can be accessed via an allowed oEmbed, use that instead.
url_to_download = url url_to_download = url

View file

@ -119,7 +119,7 @@ class ThumbnailResource(RestServlet):
if not media_info: if not media_info:
respond_404(request) respond_404(request)
return return
if media_info["quarantined_by"]: if media_info.quarantined_by:
logger.info("Media is quarantined") logger.info("Media is quarantined")
respond_404(request) respond_404(request)
return return
@ -134,7 +134,7 @@ class ThumbnailResource(RestServlet):
thumbnail_infos, thumbnail_infos,
media_id, media_id,
media_id, media_id,
url_cache=bool(media_info["url_cache"]), url_cache=bool(media_info.url_cache),
server_name=None, server_name=None,
) )
@ -152,7 +152,7 @@ class ThumbnailResource(RestServlet):
if not media_info: if not media_info:
respond_404(request) respond_404(request)
return return
if media_info["quarantined_by"]: if media_info.quarantined_by:
logger.info("Media is quarantined") logger.info("Media is quarantined")
respond_404(request) respond_404(request)
return return
@ -168,7 +168,7 @@ class ThumbnailResource(RestServlet):
file_info = FileInfo( file_info = FileInfo(
server_name=None, server_name=None,
file_id=media_id, file_id=media_id,
url_cache=media_info["url_cache"], url_cache=bool(media_info.url_cache),
thumbnail=info, thumbnail=info,
) )
@ -188,7 +188,7 @@ class ThumbnailResource(RestServlet):
desired_height, desired_height,
desired_method, desired_method,
desired_type, desired_type,
url_cache=bool(media_info["url_cache"]), url_cache=bool(media_info.url_cache),
) )
if file_path: if file_path:
@ -213,7 +213,7 @@ class ThumbnailResource(RestServlet):
server_name, media_id server_name, media_id
) )
file_id = media_info["filesystem_id"] file_id = media_info.filesystem_id
for info in thumbnail_infos: for info in thumbnail_infos:
t_w = info.width == desired_width t_w = info.width == desired_width
@ -224,7 +224,7 @@ class ThumbnailResource(RestServlet):
if t_w and t_h and t_method and t_type: if t_w and t_h and t_method and t_type:
file_info = FileInfo( file_info = FileInfo(
server_name=server_name, server_name=server_name,
file_id=media_info["filesystem_id"], file_id=file_id,
thumbnail=info, thumbnail=info,
) )
@ -280,7 +280,7 @@ class ThumbnailResource(RestServlet):
m_type, m_type,
thumbnail_infos, thumbnail_infos,
media_id, media_id,
media_info["filesystem_id"], media_info.filesystem_id,
url_cache=False, url_cache=False,
server_name=server_name, server_name=server_name,
) )

View file

@ -15,9 +15,7 @@
from enum import Enum from enum import Enum
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any,
Collection, Collection,
Dict,
Iterable, Iterable,
List, List,
Optional, Optional,
@ -54,11 +52,32 @@ class LocalMedia:
media_length: int media_length: int
upload_name: str upload_name: str
created_ts: int created_ts: int
url_cache: Optional[str]
last_access_ts: int last_access_ts: int
quarantined_by: Optional[str] quarantined_by: Optional[str]
safe_from_quarantine: bool safe_from_quarantine: bool
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RemoteMedia:
media_origin: str
media_id: str
media_type: str
media_length: int
upload_name: Optional[str]
filesystem_id: str
created_ts: int
last_access_ts: int
quarantined_by: Optional[str]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class UrlCache:
response_code: int
expires_ts: int
og: Union[str, bytes]
class MediaSortOrder(Enum): class MediaSortOrder(Enum):
""" """
Enum to define the sorting method used when returning media with Enum to define the sorting method used when returning media with
@ -165,13 +184,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name: str = hs.hostname self.server_name: str = hs.hostname
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: async def get_local_media(self, media_id: str) -> Optional[LocalMedia]:
"""Get the metadata for a local piece of media """Get the metadata for a local piece of media
Returns: Returns:
None if the media_id doesn't exist. None if the media_id doesn't exist.
""" """
return await self.db_pool.simple_select_one( row = await self.db_pool.simple_select_one(
"local_media_repository", "local_media_repository",
{"media_id": media_id}, {"media_id": media_id},
( (
@ -181,11 +200,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"created_ts", "created_ts",
"quarantined_by", "quarantined_by",
"url_cache", "url_cache",
"last_access_ts",
"safe_from_quarantine", "safe_from_quarantine",
), ),
allow_none=True, allow_none=True,
desc="get_local_media", desc="get_local_media",
) )
if row is None:
return None
return LocalMedia(media_id=media_id, **row)
async def get_local_media_by_user_paginate( async def get_local_media_by_user_paginate(
self, self,
@ -236,6 +259,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length, media_length,
upload_name, upload_name,
created_ts, created_ts,
url_cache,
last_access_ts, last_access_ts,
quarantined_by, quarantined_by,
safe_from_quarantine safe_from_quarantine
@ -257,9 +281,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length=row[2], media_length=row[2],
upload_name=row[3], upload_name=row[3],
created_ts=row[4], created_ts=row[4],
last_access_ts=row[5], url_cache=row[5],
quarantined_by=row[6], last_access_ts=row[6],
safe_from_quarantine=bool(row[7]), quarantined_by=row[7],
safe_from_quarantine=bool(row[8]),
) )
for row in txn for row in txn
] ]
@ -390,51 +415,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="mark_local_media_as_safe", desc="mark_local_media_as_safe",
) )
async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]: async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
"""Get the media_id and ts for a cached URL as of the given timestamp """Get the media_id and ts for a cached URL as of the given timestamp
Returns: Returns:
None if the URL isn't cached. None if the URL isn't cached.
""" """
def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]:
# get the most recently cached result (relative to the given ts) # get the most recently cached result (relative to the given ts)
sql = ( sql = """
"SELECT response_code, etag, expires_ts, og, media_id, download_ts" SELECT response_code, expires_ts, og
" FROM local_media_repository_url_cache" FROM local_media_repository_url_cache
" WHERE url = ? AND download_ts <= ?" WHERE url = ? AND download_ts <= ?
" ORDER BY download_ts DESC LIMIT 1" ORDER BY download_ts DESC LIMIT 1
) """
txn.execute(sql, (url, ts)) txn.execute(sql, (url, ts))
row = txn.fetchone() row = txn.fetchone()
if not row: if not row:
# ...or if we've requested a timestamp older than the oldest # ...or if we've requested a timestamp older than the oldest
# copy in the cache, return the oldest copy (if any) # copy in the cache, return the oldest copy (if any)
sql = ( sql = """
"SELECT response_code, etag, expires_ts, og, media_id, download_ts" SELECT response_code, expires_ts, og
" FROM local_media_repository_url_cache" FROM local_media_repository_url_cache
" WHERE url = ? AND download_ts > ?" WHERE url = ? AND download_ts > ?
" ORDER BY download_ts ASC LIMIT 1" ORDER BY download_ts ASC LIMIT 1
) """
txn.execute(sql, (url, ts)) txn.execute(sql, (url, ts))
row = txn.fetchone() row = txn.fetchone()
if not row: if not row:
return None return None
return dict( return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2])
zip(
(
"response_code",
"etag",
"expires_ts",
"og",
"media_id",
"download_ts",
),
row,
)
)
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
@ -444,7 +457,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
response_code: int, response_code: int,
etag: Optional[str], etag: Optional[str],
expires_ts: int, expires_ts: int,
og: Optional[str], og: str,
media_id: str, media_id: str,
download_ts: int, download_ts: int,
) -> None: ) -> None:
@ -510,8 +523,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_cached_remote_media( async def get_cached_remote_media(
self, origin: str, media_id: str self, origin: str, media_id: str
) -> Optional[Dict[str, Any]]: ) -> Optional[RemoteMedia]:
return await self.db_pool.simple_select_one( row = await self.db_pool.simple_select_one(
"remote_media_cache", "remote_media_cache",
{"media_origin": origin, "media_id": media_id}, {"media_origin": origin, "media_id": media_id},
( (
@ -520,11 +533,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"upload_name", "upload_name",
"created_ts", "created_ts",
"filesystem_id", "filesystem_id",
"last_access_ts",
"quarantined_by", "quarantined_by",
), ),
allow_none=True, allow_none=True,
desc="get_cached_remote_media", desc="get_cached_remote_media",
) )
if row is None:
return row
return RemoteMedia(media_origin=origin, media_id=media_id, **row)
async def store_cached_remote_media( async def store_cached_remote_media(
self, self,
@ -623,10 +640,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
t_width: int, t_width: int,
t_height: int, t_height: int,
t_type: str, t_type: str,
) -> Optional[Dict[str, Any]]: ) -> Optional[ThumbnailInfo]:
"""Fetch the thumbnail info of given width, height and type.""" """Fetch the thumbnail info of given width, height and type."""
return await self.db_pool.simple_select_one( row = await self.db_pool.simple_select_one(
table="remote_media_cache_thumbnails", table="remote_media_cache_thumbnails",
keyvalues={ keyvalues={
"media_origin": origin, "media_origin": origin,
@ -641,11 +658,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"thumbnail_method", "thumbnail_method",
"thumbnail_type", "thumbnail_type",
"thumbnail_length", "thumbnail_length",
"filesystem_id",
), ),
allow_none=True, allow_none=True,
desc="get_remote_media_thumbnail", desc="get_remote_media_thumbnail",
) )
if row is None:
return None
return ThumbnailInfo(
width=row["thumbnail_width"],
height=row["thumbnail_height"],
method=row["thumbnail_method"],
type=row["thumbnail_type"],
length=row["thumbnail_length"],
)
@trace @trace
async def store_remote_media_thumbnail( async def store_remote_media_thumbnail(

View file

@ -504,7 +504,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
origin, media_id = self.media_id.split("/") origin, media_id = self.media_id.split("/")
info = self.get_success(self.store.get_cached_remote_media(origin, media_id)) info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
assert info is not None assert info is not None
file_id = info["filesystem_id"] file_id = info.filesystem_id
thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir( thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
origin, file_id origin, file_id

View file

@ -642,7 +642,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None assert media_info is not None
self.assertFalse(media_info["quarantined_by"]) self.assertFalse(media_info.quarantined_by)
# quarantining # quarantining
channel = self.make_request( channel = self.make_request(
@ -656,7 +656,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None assert media_info is not None
self.assertTrue(media_info["quarantined_by"]) self.assertTrue(media_info.quarantined_by)
# remove from quarantine # remove from quarantine
channel = self.make_request( channel = self.make_request(
@ -670,7 +670,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None assert media_info is not None
self.assertFalse(media_info["quarantined_by"]) self.assertFalse(media_info.quarantined_by)
def test_quarantine_protected_media(self) -> None: def test_quarantine_protected_media(self) -> None:
""" """
@ -683,7 +683,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
# verify protection # verify protection
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None assert media_info is not None
self.assertTrue(media_info["safe_from_quarantine"]) self.assertTrue(media_info.safe_from_quarantine)
# quarantining # quarantining
channel = self.make_request( channel = self.make_request(
@ -698,7 +698,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
# verify that is not in quarantine # verify that is not in quarantine
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None assert media_info is not None
self.assertFalse(media_info["quarantined_by"]) self.assertFalse(media_info.quarantined_by)
class ProtectMediaByIDTestCase(_AdminMediaTests): class ProtectMediaByIDTestCase(_AdminMediaTests):
@ -756,7 +756,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None assert media_info is not None
self.assertFalse(media_info["safe_from_quarantine"]) self.assertFalse(media_info.safe_from_quarantine)
# protect # protect
channel = self.make_request( channel = self.make_request(
@ -770,7 +770,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None assert media_info is not None
self.assertTrue(media_info["safe_from_quarantine"]) self.assertTrue(media_info.safe_from_quarantine)
# unprotect # unprotect
channel = self.make_request( channel = self.make_request(
@ -784,7 +784,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None assert media_info is not None
self.assertFalse(media_info["safe_from_quarantine"]) self.assertFalse(media_info.safe_from_quarantine)
class PurgeMediaCacheTestCase(_AdminMediaTests): class PurgeMediaCacheTestCase(_AdminMediaTests):

View file

@ -267,23 +267,23 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None: def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None:
"""Given an MXC URI, assert whether it has been purged or not.""" """Given an MXC URI, assert whether it has been purged or not."""
if mxc_uri.server_name == self.hs.config.server.server_name: if mxc_uri.server_name == self.hs.config.server.server_name:
found_media_dict = self.get_success( found_media = bool(
self.store.get_local_media(mxc_uri.media_id) self.get_success(self.store.get_local_media(mxc_uri.media_id))
) )
else: else:
found_media_dict = self.get_success( found_media = bool(
self.store.get_cached_remote_media( self.get_success(
mxc_uri.server_name, mxc_uri.media_id self.store.get_cached_remote_media(
mxc_uri.server_name, mxc_uri.media_id
)
) )
) )
if expect_purged: if expect_purged:
self.assertIsNone( self.assertFalse(found_media, msg=f"{mxc_uri} unexpectedly not purged")
found_media_dict, msg=f"{mxc_uri} unexpectedly not purged"
)
else: else:
self.assertIsNotNone( self.assertTrue(
found_media_dict, found_media,
msg=f"{mxc_uri} unexpectedly purged", msg=f"{mxc_uri} unexpectedly purged",
) )