diff --git a/synapse/media/v1/base_resource.py b/synapse/media/v1/base_resource.py index 77b05c6548..14735ff375 100644 --- a/synapse/media/v1/base_resource.py +++ b/synapse/media/v1/base_resource.py @@ -45,6 +45,7 @@ class BaseMediaResource(Resource): self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels self.filepaths = filepaths + self.downloads = {} @staticmethod def catch_errors(request_handler): @@ -128,6 +129,28 @@ class BaseMediaResource(Resource): if not os.path.exists(dirname): os.makedirs(dirname) + def _get_remote_media(self, server_name, media_id): + key = (server_name, media_id) + download = self.downloads.get(key) + if download is None: + download = self._get_remote_media_impl(server_name, media_id) + self.downloads[key] = download + @download.addBoth + def callback(media_info): + del self.downloads[key] + return download + + @defer.inlineCallbacks + def _get_remote_media_impl(self, server_name, media_id): + media_info = yield self.store.get_cached_remote_media( + server_name, media_id + ) + if not media_info: + media_info = yield self._download_remote_file( + server_name, media_id + ) + defer.returnValue(media_info) + @defer.inlineCallbacks def _download_remote_file(self, server_name, media_id): file_id = random_string(24) @@ -231,7 +254,7 @@ class BaseMediaResource(Resource): if m_width * m_height >= self.max_image_pixels: logger.info( - "Image too large to thumbnail %r x %r > %r" + "Image too large to thumbnail %r x %r > %r", m_width, m_height, self.max_image_pixels ) return @@ -294,7 +317,7 @@ class BaseMediaResource(Resource): if m_width * m_height >= self.max_image_pixels: logger.info( - "Image too large to thumbnail %r x %r > %r" + "Image too large to thumbnail %r x %r > %r", m_width, m_height, self.max_image_pixels ) return diff --git a/synapse/media/v1/download_resource.py b/synapse/media/v1/download_resource.py index 6de0932ba3..f3a6804e05 100644 --- a/synapse/media/v1/download_resource.py +++ b/synapse/media/v1/download_resource.py @@ -56,14 +56,7 @@ class DownloadResource(BaseMediaResource): @defer.inlineCallbacks def _respond_remote_file(self, request, server_name, media_id): - media_info = yield self.store.get_cached_remote_media( - server_name, media_id - ) - - if not media_info: - media_info = yield self._download_remote_file( - server_name, media_id - ) + media_info = yield self._get_remote_media(server_name, media_id) media_type = media_info["media_type"] filesystem_id = media_info["filesystem_id"] diff --git a/synapse/media/v1/thumbnail_resource.py b/synapse/media/v1/thumbnail_resource.py index fd08c7ecd2..e19620d456 100644 --- a/synapse/media/v1/thumbnail_resource.py +++ b/synapse/media/v1/thumbnail_resource.py @@ -83,16 +83,9 @@ class ThumbnailResource(BaseMediaResource): @defer.inlineCallbacks def _respond_remote_thumbnail(self, request, server_name, media_id, width, height, method, m_type): - media_info = yield self.store.get_cached_remote_media( - server_name, media_id - ) - - if not media_info: - # TODO: Don't download the whole remote file - # We should proxy the thumbnail from the remote server instead. - media_info = yield self._download_remote_file( - server_name, media_id - ) + # TODO: Don't download the whole remote file + # We should proxy the thumbnail from the remote server instead. + media_info = yield self._get_remote_media(server_name, media_id) thumbnail_infos = yield self.store.get_remote_media_thumbnails( server_name, media_id,