From d34c6e1279a24c5eb8afb962a29950c85fbfbf8a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 15 Jan 2021 10:57:37 -0500 Subject: [PATCH] Add type hints to media rest resources. (#9093) --- changelog.d/9093.misc | 1 + synapse/rest/media/v1/_base.py | 76 ++++++++++------- synapse/rest/media/v1/config_resource.py | 14 +++- synapse/rest/media/v1/download_resource.py | 18 +++-- synapse/rest/media/v1/filepath.py | 50 ++++++++---- synapse/rest/media/v1/media_repository.py | 50 +++++++----- synapse/rest/media/v1/media_storage.py | 12 +-- synapse/rest/media/v1/preview_url_resource.py | 77 +++++++++++------- synapse/rest/media/v1/storage_provider.py | 37 +++++---- synapse/rest/media/v1/thumbnail_resource.py | 81 ++++++++++++------- synapse/rest/media/v1/thumbnailer.py | 18 +++-- synapse/rest/media/v1/upload_resource.py | 14 +++- .../databases/main/media_repository.py | 3 +- 13 files changed, 286 insertions(+), 165 deletions(-) create mode 100644 changelog.d/9093.misc diff --git a/changelog.d/9093.misc b/changelog.d/9093.misc new file mode 100644 index 0000000000..53eb8f72a8 --- /dev/null +++ b/changelog.d/9093.misc @@ -0,0 +1 @@ +Add type hints to media repository. diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 47c2b44bff..31a41e4a27 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 New Vector Ltd +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,10 +17,11 @@ import logging import os import urllib -from typing import Awaitable +from typing import Awaitable, Dict, Generator, List, Optional, Tuple from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender +from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError, cs_error from synapse.http.server import finish_request, respond_with_json @@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [ ] -def parse_media_id(request): +def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: try: # This allows users to append e.g. /test.png to the URL. Useful for # clients that parse the URL to see content type. @@ -69,7 +70,7 @@ def parse_media_id(request): ) -def respond_404(request): +def respond_404(request: Request) -> None: respond_with_json( request, 404, @@ -79,8 +80,12 @@ def respond_404(request): async def respond_with_file( - request, media_type, file_path, file_size=None, upload_name=None -): + request: Request, + media_type: str, + file_path: str, + file_size: Optional[int] = None, + upload_name: Optional[str] = None, +) -> None: logger.debug("Responding with %r", file_path) if os.path.isfile(file_path): @@ -98,15 +103,20 @@ async def respond_with_file( respond_404(request) -def add_file_headers(request, media_type, file_size, upload_name): +def add_file_headers( + request: Request, + media_type: str, + file_size: Optional[int], + upload_name: Optional[str], +) -> None: """Adds the correct response headers in preparation for responding with the media. Args: - request (twisted.web.http.Request) - media_type (str): The media/content type. - file_size (int): Size in bytes of the media, if known. - upload_name (str): The name of the requested file, if any. + request + media_type: The media/content type. + file_size: Size in bytes of the media, if known. + upload_name: The name of the requested file, if any. """ def _quote(x): @@ -153,7 +163,8 @@ def add_file_headers(request, media_type, file_size, upload_name): # select private. don't bother setting Expires as all our # clients are smart enough to be happy with Cache-Control request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") - request.setHeader(b"Content-Length", b"%d" % (file_size,)) + if file_size is not None: + request.setHeader(b"Content-Length", b"%d" % (file_size,)) # Tell web crawlers to not index, archive, or follow links in media. This # should help to prevent things in the media repo from showing up in web @@ -184,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = { } -def _can_encode_filename_as_token(x): +def _can_encode_filename_as_token(x: str) -> bool: for c in x: # from RFC2616: # @@ -206,17 +217,21 @@ def _can_encode_filename_as_token(x): async def respond_with_responder( - request, responder, media_type, file_size, upload_name=None -): + request: Request, + responder: "Optional[Responder]", + media_type: str, + file_size: Optional[int], + upload_name: Optional[str] = None, +) -> None: """Responds to the request with given responder. If responder is None then returns 404. Args: - request (twisted.web.http.Request) - responder (Responder|None) - media_type (str): The media/content type. - file_size (int|None): Size in bytes of the media. If not known it should be None - upload_name (str|None): The name of the requested file, if any. + request + responder + media_type: The media/content type. + file_size: Size in bytes of the media. If not known it should be None + upload_name: The name of the requested file, if any. """ if request._disconnected: logger.warning( @@ -308,22 +323,22 @@ class FileInfo: self.thumbnail_type = thumbnail_type -def get_filename_from_headers(headers): +def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: """ Get the filename of the downloaded file by inspecting the Content-Disposition HTTP header. Args: - headers (dict[bytes, list[bytes]]): The HTTP request headers. + headers: The HTTP request headers. Returns: - A Unicode string of the filename, or None. + The filename, or None. """ content_disposition = headers.get(b"Content-Disposition", [b""]) # No header, bail out. if not content_disposition[0]: - return + return None _, params = _parse_header(content_disposition[0]) @@ -356,17 +371,16 @@ def get_filename_from_headers(headers): return upload_name -def _parse_header(line): +def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: """Parse a Content-type like header. Cargo-culted from `cgi`, but works on bytes rather than strings. Args: - line (bytes): header to be parsed + line: header to be parsed Returns: - Tuple[bytes, dict[bytes, bytes]]: - the main content-type, followed by the parameter dictionary + The main content-type, followed by the parameter dictionary """ parts = _parseparam(b";" + line) key = next(parts) @@ -386,16 +400,16 @@ def _parse_header(line): return key, pdict -def _parseparam(s): +def _parseparam(s: bytes) -> Generator[bytes, None, None]: """Generator which splits the input on ;, respecting double-quoted sequences Cargo-culted from `cgi`, but works on bytes rather than strings. Args: - s (bytes): header to be parsed + s: header to be parsed Returns: - Iterable[bytes]: the split input + The split input """ while s[:1] == b";": s = s[1:] diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index 68dd2a1c8a..4e4c6971f7 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 Will Hunt +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,22 +15,29 @@ # limitations under the License. # +from typing import TYPE_CHECKING + +from twisted.web.http import Request + from synapse.http.server import DirectServeJsonResource, respond_with_json +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + class MediaConfigResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() config = hs.get_config() self.clock = hs.get_clock() self.auth = hs.get_auth() self.limits_dict = {"m.upload.size": config.max_upload_size} - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: await self.auth.get_user_by_req(request) respond_with_json(request, 200, self.limits_dict, send_cors=True) - async def _async_render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request: Request) -> None: respond_with_json(request, 200, {}, send_cors=True) diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index d3d8457303..3ed219ae43 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,24 +14,31 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request -import synapse.http.servlet from synapse.http.server import DirectServeJsonResource, set_cors_headers +from synapse.http.servlet import parse_boolean from ._base import parse_media_id, respond_404 +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) class DownloadResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo): + def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): super().__init__() self.media_repo = media_repo self.server_name = hs.hostname - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: set_cors_headers(request) request.setHeader( b"Content-Security-Policy", @@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource): if server_name == self.server_name: await self.media_repo.get_local_media(request, media_id, name) else: - allow_remote = synapse.http.servlet.parse_boolean( - request, "allow_remote", default=True - ) + allow_remote = parse_boolean(request, "allow_remote", default=True) if not allow_remote: logger.info( "Rejecting request for remote media %s/%s due to allow_remote", diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index 9e079f672f..7792f26e78 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,11 +17,12 @@ import functools import os import re +from typing import Callable, List NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") -def _wrap_in_base_path(func): +def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]": """Takes a function that returns a relative path and turns it into an absolute path based on the location of the primary media store """ @@ -41,12 +43,18 @@ class MediaFilePaths: to write to the backup media store (when one is configured) """ - def __init__(self, primary_base_path): + def __init__(self, primary_base_path: str): self.base_path = primary_base_path def default_thumbnail_rel( - self, default_top_level, default_sub_type, width, height, content_type, method - ): + self, + default_top_level: str, + default_sub_type: str, + width: int, + height: int, + content_type: str, + method: str, + ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( @@ -55,12 +63,14 @@ class MediaFilePaths: default_thumbnail = _wrap_in_base_path(default_thumbnail_rel) - def local_media_filepath_rel(self, media_id): + def local_media_filepath_rel(self, media_id: str) -> str: return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:]) local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) - def local_media_thumbnail_rel(self, media_id, width, height, content_type, method): + def local_media_thumbnail_rel( + self, media_id: str, width: int, height: int, content_type: str, method: str + ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( @@ -86,7 +96,7 @@ class MediaFilePaths: media_id[4:], ) - def remote_media_filepath_rel(self, server_name, file_id): + def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: return os.path.join( "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] ) @@ -94,8 +104,14 @@ class MediaFilePaths: remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) def remote_media_thumbnail_rel( - self, server_name, file_id, width, height, content_type, method - ): + self, + server_name: str, + file_id: str, + width: int, + height: int, + content_type: str, + method: str, + ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( @@ -113,7 +129,7 @@ class MediaFilePaths: # Should be removed after some time, when most of the thumbnails are stored # using the new path. def remote_media_thumbnail_rel_legacy( - self, server_name, file_id, width, height, content_type + self, server_name: str, file_id: str, width: int, height: int, content_type: str ): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) @@ -126,7 +142,7 @@ class MediaFilePaths: file_name, ) - def remote_media_thumbnail_dir(self, server_name, file_id): + def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: return os.path.join( self.base_path, "remote_thumbnail", @@ -136,7 +152,7 @@ class MediaFilePaths: file_id[4:], ) - def url_cache_filepath_rel(self, media_id): + def url_cache_filepath_rel(self, media_id: str) -> str: if NEW_FORMAT_ID_RE.match(media_id): # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -146,7 +162,7 @@ class MediaFilePaths: url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) - def url_cache_filepath_dirs_to_delete(self, media_id): + def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id file" if NEW_FORMAT_ID_RE.match(media_id): return [os.path.join(self.base_path, "url_cache", media_id[:10])] @@ -156,7 +172,9 @@ class MediaFilePaths: os.path.join(self.base_path, "url_cache", media_id[0:2]), ] - def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method): + def url_cache_thumbnail_rel( + self, media_id: str, width: int, height: int, content_type: str, method: str + ) -> str: # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -178,7 +196,7 @@ class MediaFilePaths: url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) - def url_cache_thumbnail_directory(self, media_id): + def url_cache_thumbnail_directory(self, media_id: str) -> str: # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -195,7 +213,7 @@ class MediaFilePaths: media_id[4:], ) - def url_cache_thumbnail_dirs_to_delete(self, media_id): + def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id thumbnails" # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 83beb02b05..4c9946a616 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import errno import logging import os import shutil -from typing import IO, Dict, List, Optional, Tuple +from io import BytesIO +from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple import twisted.internet.error import twisted.web.http @@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource from .thumbnailer import Thumbnailer, ThumbnailError from .upload_resource import UploadResource +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -63,7 +66,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 class MediaRepository: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.client = hs.get_federation_http_client() @@ -73,16 +76,16 @@ class MediaRepository: self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels - self.primary_base_path = hs.config.media_store_path - self.filepaths = MediaFilePaths(self.primary_base_path) + self.primary_base_path = hs.config.media_store_path # type: str + self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements self.remote_media_linearizer = Linearizer(name="media_remote") - self.recently_accessed_remotes = set() - self.recently_accessed_locals = set() + self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]] + self.recently_accessed_locals = set() # type: Set[str] self.federation_domain_whitelist = hs.config.federation_domain_whitelist @@ -113,7 +116,7 @@ class MediaRepository: "update_recently_accessed_media", self._update_recently_accessed ) - async def _update_recently_accessed(self): + async def _update_recently_accessed(self) -> None: remote_media = self.recently_accessed_remotes self.recently_accessed_remotes = set() @@ -124,12 +127,12 @@ class MediaRepository: local_media, remote_media, self.clock.time_msec() ) - def mark_recently_accessed(self, server_name, media_id): + def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None: """Mark the given media as recently accessed. Args: - server_name (str|None): Origin server of media, or None if local - media_id (str): The media ID of the content + server_name: Origin server of media, or None if local + media_id: The media ID of the content """ if server_name: self.recently_accessed_remotes.add((server_name, media_id)) @@ -459,7 +462,14 @@ class MediaRepository: def _get_thumbnail_requirements(self, media_type): return self.thumbnail_requirements.get(media_type, ()) - def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type): + def _generate_thumbnail( + self, + thumbnailer: Thumbnailer, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + ) -> Optional[BytesIO]: m_width = thumbnailer.width m_height = thumbnailer.height @@ -470,22 +480,20 @@ class MediaRepository: m_height, self.max_image_pixels, ) - return + return None if thumbnailer.transpose_method is not None: m_width, m_height = thumbnailer.transpose() if t_method == "crop": - t_byte_source = thumbnailer.crop(t_width, t_height, t_type) + return thumbnailer.crop(t_width, t_height, t_type) elif t_method == "scale": t_width, t_height = thumbnailer.aspect(t_width, t_height) t_width = min(m_width, t_width) t_height = min(m_height, t_height) - t_byte_source = thumbnailer.scale(t_width, t_height, t_type) - else: - t_byte_source = None + return thumbnailer.scale(t_width, t_height, t_type) - return t_byte_source + return None async def generate_local_exact_thumbnail( self, @@ -776,7 +784,7 @@ class MediaRepository: return {"width": m_width, "height": m_height} - async def delete_old_remote_media(self, before_ts): + async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]: old_media = await self.store.get_remote_media_before(before_ts) deleted = 0 @@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource): within a given rectangle. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): # If we're not configured to use it, raise if we somehow got here. if not hs.config.can_load_media_repo: raise ConfigError("Synapse is not configured to use a media repo.") diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 268e0c8f50..89cdd605aa 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vecotr Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ import os import shutil from typing import IO, TYPE_CHECKING, Any, Optional, Sequence +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender from synapse.logging.context import defer_to_thread, make_deferred_yieldable @@ -270,7 +272,7 @@ class MediaStorage: return self.filepaths.local_media_filepath_rel(file_info.file_id) -def _write_file_synchronously(source, dest): +def _write_file_synchronously(source: IO, dest: IO) -> None: """Write `source` to the file like `dest` synchronously. Should be called from a thread. @@ -286,14 +288,14 @@ class FileResponder(Responder): """Wraps an open file that can be sent to a request. Args: - open_file (file): A file like object to be streamed ot the client, + open_file: A file like object to be streamed ot the client, is closed when finished streaming. """ - def __init__(self, open_file): + def __init__(self, open_file: IO): self.open_file = open_file - def write_to_consumer(self, consumer): + def write_to_consumer(self, consumer: IConsumer) -> Deferred: return make_deferred_yieldable( FileSender().beginFileTransfer(self.open_file, consumer) ) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 1082389d9b..a632099167 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import datetime import errno import fnmatch @@ -23,12 +23,13 @@ import re import shutil import sys import traceback -from typing import Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union from urllib import parse as urlparse import attr from twisted.internet.error import DNSLookupError +from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError from synapse.http.client import SimpleHttpClient @@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers +from synapse.rest.media.v1.media_storage import MediaStorage from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache @@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string from ._base import FileInfo +if TYPE_CHECKING: + from lxml import etree + + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I) @@ -119,7 +127,12 @@ class OEmbedError(Exception): class PreviewUrlResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo, media_storage): + def __init__( + self, + hs: "HomeServer", + media_repo: "MediaRepository", + media_storage: MediaStorage, + ): super().__init__() self.auth = hs.get_auth() @@ -166,11 +179,11 @@ class PreviewUrlResource(DirectServeJsonResource): self._start_expire_url_cache_data, 10 * 1000 ) - async def _async_render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request: Request) -> None: request.setHeader(b"Allow", b"OPTIONS, GET") respond_with_json(request, 200, {}, send_cors=True) - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: # XXX: if get_user_by_req fails, what should we do in an async render? requester = await self.auth.get_user_by_req(request) @@ -450,7 +463,7 @@ class PreviewUrlResource(DirectServeJsonResource): logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) raise OEmbedError() from e - async def _download_url(self, url: str, user): + async def _download_url(self, url: str, user: str) -> Dict[str, Any]: # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -580,7 +593,7 @@ class PreviewUrlResource(DirectServeJsonResource): "expire_url_cache_data", self._expire_url_cache_data ) - async def _expire_url_cache_data(self): + async def _expire_url_cache_data(self) -> None: """Clean up expired url cache content, media and thumbnails. """ # TODO: Delete from backup media store @@ -676,7 +689,9 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("No media removed from url cache") -def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]: +def decode_and_calc_og( + body: bytes, media_uri: str, request_encoding: Optional[str] = None +) -> Dict[str, Optional[str]]: # If there's no body, nothing useful is going to be found. if not body: return {} @@ -697,7 +712,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str] return og -def _calc_og(tree, media_uri): +def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]: # suck our tree into lxml and define our OG response. # if we see any image URLs in the OG response, then spider them @@ -801,7 +816,9 @@ def _calc_og(tree, media_uri): for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) ) og["og:description"] = summarize_paragraphs(text_nodes) - else: + elif og["og:description"]: + # This must be a non-empty string at this point. + assert isinstance(og["og:description"], str) og["og:description"] = summarize_paragraphs([og["og:description"]]) # TODO: delete the url downloads to stop diskfilling, @@ -809,7 +826,9 @@ def _calc_og(tree, media_uri): return og -def _iterate_over_text(tree, *tags_to_ignore): +def _iterate_over_text( + tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]] +) -> Generator[str, None, None]: """Iterate over the tree returning text nodes in a depth first fashion, skipping text nodes inside certain tags. """ @@ -843,32 +862,32 @@ def _iterate_over_text(tree, *tags_to_ignore): ) -def _rebase_url(url, base): - base = list(urlparse.urlparse(base)) - url = list(urlparse.urlparse(url)) - if not url[0]: # fix up schema - url[0] = base[0] or "http" - if not url[1]: # fix up hostname - url[1] = base[1] - if not url[2].startswith("/"): - url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2] - return urlparse.urlunparse(url) +def _rebase_url(url: str, base: str) -> str: + base_parts = list(urlparse.urlparse(base)) + url_parts = list(urlparse.urlparse(url)) + if not url_parts[0]: # fix up schema + url_parts[0] = base_parts[0] or "http" + if not url_parts[1]: # fix up hostname + url_parts[1] = base_parts[1] + if not url_parts[2].startswith("/"): + url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2] + return urlparse.urlunparse(url_parts) -def _is_media(content_type): - if content_type.lower().startswith("image/"): - return True +def _is_media(content_type: str) -> bool: + return content_type.lower().startswith("image/") -def _is_html(content_type): +def _is_html(content_type: str) -> bool: content_type = content_type.lower() - if content_type.startswith("text/html") or content_type.startswith( + return content_type.startswith("text/html") or content_type.startswith( "application/xhtml" - ): - return True + ) -def summarize_paragraphs(text_nodes, min_size=200, max_size=500): +def summarize_paragraphs( + text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500 +) -> Optional[str]: # Try to get a summary of between 200 and 500 words, respecting # first paragraph and then word boundaries. # TODO: Respect sentences? diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 67f67efde7..e92006faa9 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging import os import shutil -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background @@ -27,13 +28,17 @@ from .media_storage import FileResponder logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer -class StorageProvider: + +class StorageProvider(metaclass=abc.ABCMeta): """A storage provider is a service that can store uploaded media and retrieve them. """ - async def store_file(self, path: str, file_info: FileInfo): + @abc.abstractmethod + async def store_file(self, path: str, file_info: FileInfo) -> None: """Store the file described by file_info. The actual contents can be retrieved by reading the file in file_info.upload_path. @@ -42,6 +47,7 @@ class StorageProvider: file_info: The metadata of the file. """ + @abc.abstractmethod async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """Attempt to fetch the file described by file_info and stream it into writer. @@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider): self.store_synchronous = store_synchronous self.store_remote = store_remote - def __str__(self): + def __str__(self) -> str: return "StorageProviderWrapper[%s]" % (self.backend,) - async def store_file(self, path, file_info): + async def store_file(self, path: str, file_info: FileInfo) -> None: if not file_info.server_name and not self.store_local: return None @@ -91,7 +97,7 @@ class StorageProviderWrapper(StorageProvider): if self.store_synchronous: # store_file is supposed to return an Awaitable, but guard # against improper implementations. - return await maybe_awaitable(self.backend.store_file(path, file_info)) + await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore else: # TODO: Handle errors. async def store(): @@ -103,9 +109,8 @@ class StorageProviderWrapper(StorageProvider): logger.exception("Error storing file") run_in_background(store) - return None - async def fetch(self, path, file_info): + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: # store_file is supposed to return an Awaitable, but guard # against improper implementations. return await maybe_awaitable(self.backend.fetch(path, file_info)) @@ -115,11 +120,11 @@ class FileStorageProviderBackend(StorageProvider): """A storage provider that stores files in a directory on a filesystem. Args: - hs (HomeServer) + hs config: The config returned by `parse_config`. """ - def __init__(self, hs, config): + def __init__(self, hs: "HomeServer", config: str): self.hs = hs self.cache_directory = hs.config.media_store_path self.base_directory = config @@ -127,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider): def __str__(self): return "FileStorageProviderBackend[%s]" % (self.base_directory,) - async def store_file(self, path, file_info): + async def store_file(self, path: str, file_info: FileInfo) -> None: """See StorageProvider.store_file""" primary_fname = os.path.join(self.cache_directory, path) @@ -137,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider): if not os.path.exists(dirname): os.makedirs(dirname) - return await defer_to_thread( + await defer_to_thread( self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) - async def fetch(self, path, file_info): + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """See StorageProvider.fetch""" backup_fname = os.path.join(self.base_directory, path) if os.path.isfile(backup_fname): return FileResponder(open(backup_fname, "rb")) + return None + @staticmethod - def parse_config(config): + def parse_config(config: dict) -> str: """Called on startup to parse config supplied. This should parse the config and raise if there is a problem. diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 30421b663a..d6880f2e6e 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +16,14 @@ import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request from synapse.api.errors import SynapseError from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.servlet import parse_integer, parse_string +from synapse.rest.media.v1.media_storage import MediaStorage from ._base import ( FileInfo, @@ -28,13 +33,22 @@ from ._base import ( respond_with_responder, ) +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) class ThumbnailResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo, media_storage): + def __init__( + self, + hs: "HomeServer", + media_repo: "MediaRepository", + media_storage: MediaStorage, + ): super().__init__() self.store = hs.get_datastore() @@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource): self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.server_name = hs.hostname - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: set_cors_headers(request) server_name, media_id, _ = parse_media_id(request) width = parse_integer(request, "width", required=True) @@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource): self.media_repo.mark_recently_accessed(server_name, media_id) async def _respond_local_thumbnail( - self, request, media_id, width, height, method, m_type - ): + self, + request: Request, + media_id: str, + width: int, + height: int, + method: str, + m_type: str, + ) -> None: media_info = await self.store.get_local_media(media_id) if not media_info: @@ -114,13 +134,13 @@ class ThumbnailResource(DirectServeJsonResource): async def _select_or_generate_local_thumbnail( self, - request, - media_id, - desired_width, - desired_height, - desired_method, - desired_type, - ): + request: Request, + media_id: str, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + ) -> None: media_info = await self.store.get_local_media(media_id) if not media_info: @@ -178,14 +198,14 @@ class ThumbnailResource(DirectServeJsonResource): async def _select_or_generate_remote_thumbnail( self, - request, - server_name, - media_id, - desired_width, - desired_height, - desired_method, - desired_type, - ): + request: Request, + server_name: str, + media_id: str, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + ) -> None: media_info = await self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = await self.store.get_remote_media_thumbnails( @@ -239,8 +259,15 @@ class ThumbnailResource(DirectServeJsonResource): raise SynapseError(400, "Failed to generate thumbnail.") async def _respond_remote_thumbnail( - self, request, server_name, media_id, width, height, method, m_type - ): + self, + request: Request, + server_name: str, + media_id: str, + width: int, + height: int, + method: str, + m_type: str, + ) -> None: # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. @@ -275,12 +302,12 @@ class ThumbnailResource(DirectServeJsonResource): def _select_thumbnail( self, - desired_width, - desired_height, - desired_method, - desired_type, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, thumbnail_infos, - ): + ) -> dict: d_w = desired_width d_h = desired_height diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 32a8e4f960..07903e4017 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +15,7 @@ # limitations under the License. import logging from io import BytesIO +from typing import Tuple from PIL import Image @@ -39,7 +41,7 @@ class Thumbnailer: FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} - def __init__(self, input_path): + def __init__(self, input_path: str): try: self.image = Image.open(input_path) except OSError as e: @@ -59,11 +61,11 @@ class Thumbnailer: # A lot of parsing errors can happen when parsing EXIF logger.info("Error parsing image EXIF information: %s", e) - def transpose(self): + def transpose(self) -> Tuple[int, int]: """Transpose the image using its EXIF Orientation tag Returns: - Tuple[int, int]: (width, height) containing the new image size in pixels. + A tuple containing the new image size in pixels as (width, height). """ if self.transpose_method is not None: self.image = self.image.transpose(self.transpose_method) @@ -73,7 +75,7 @@ class Thumbnailer: self.image.info["exif"] = None return self.image.size - def aspect(self, max_width, max_height): + def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]: """Calculate the largest size that preserves aspect ratio which fits within the given rectangle:: @@ -91,7 +93,7 @@ class Thumbnailer: else: return (max_height * self.width) // self.height, max_height - def _resize(self, width, height): + def _resize(self, width: int, height: int) -> Image: # 1-bit or 8-bit color palette images need converting to RGB # otherwise they will be scaled using nearest neighbour which # looks awful @@ -99,7 +101,7 @@ class Thumbnailer: self.image = self.image.convert("RGB") return self.image.resize((width, height), Image.ANTIALIAS) - def scale(self, width, height, output_type): + def scale(self, width: int, height: int, output_type: str) -> BytesIO: """Rescales the image to the given dimensions. Returns: @@ -108,7 +110,7 @@ class Thumbnailer: scaled = self._resize(width, height) return self._encode_image(scaled, output_type) - def crop(self, width, height, output_type): + def crop(self, width: int, height: int, output_type: str) -> BytesIO: """Rescales and crops the image to the given dimensions preserving aspect:: (w_in / h_in) = (w_scaled / h_scaled) @@ -136,7 +138,7 @@ class Thumbnailer: cropped = scaled_image.crop((crop_left, 0, crop_right, height)) return self._encode_image(cropped, output_type) - def _encode_image(self, output_image, output_type): + def _encode_image(self, output_image: Image, output_type: str) -> BytesIO: output_bytes_io = BytesIO() fmt = self.FORMATS[output_type] if fmt == "JPEG": diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 42febc9afc..6da76ae994 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,18 +15,25 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_string +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) class UploadResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo): + def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): super().__init__() self.media_repo = media_repo @@ -37,10 +45,10 @@ class UploadResource(DirectServeJsonResource): self.max_upload_size = hs.config.max_upload_size self.clock = hs.get_clock() - async def _async_render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request: Request) -> None: respond_with_json(request, 200, {}, send_cors=True) - async def _async_render_POST(self, request): + async def _async_render_POST(self, request: Request) -> None: requester = await self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 4b2f224718..283c8a5e22 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -169,7 +170,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def get_local_media_before( self, before_ts: int, size_gt: int, keep_profiles: bool, - ) -> Optional[List[str]]: + ) -> List[str]: # to find files that have never been accessed (last_access_ts IS NULL) # compare with `created_ts`