Merge pull request #6664 from matrix-org/erikj/media_admin_apis

Fix media repo admin APIs when using a media worker.
This commit is contained in:
Erik Johnston 2020-01-08 15:50:06 +00:00 committed by GitHub
commit 7c232bd98b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 131 additions and 120 deletions

1
changelog.d/6664.bugfix Normal file
View file

@ -0,0 +1 @@
Fix media repo admin APIs when using a media worker.

View file

@ -34,6 +34,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.admin import register_servlets_for_media_repo
@ -47,6 +48,7 @@ logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore( class MediaRepositorySlavedStore(
RoomStore,
SlavedApplicationServiceStore, SlavedApplicationServiceStore,
SlavedRegistrationStore, SlavedRegistrationStore,
SlavedClientIpStore, SlavedClientIpStore,

View file

@ -366,6 +366,134 @@ class RoomWorkerStore(SQLBaseStore):
defer.returnValue(row) defer.returnValue(row)
def get_media_mxcs_in_room(self, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
room_id (str)
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
"""
def _get_media_mxcs_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
local_media_mxcs = []
remote_media_mxcs = []
# Convert the IDs to MXC URIs
for media_id in local_mxcs:
local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
for hostname, media_id in remote_mxcs:
remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
return local_media_mxcs, remote_media_mxcs
return self.db.runInteraction(
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
)
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines
the associated media
"""
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
total_media_quarantined = 0
# Now update all the tables to set the quarantined_by flag
txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
""",
((quarantined_by, media_id) for media_id in local_mxcs),
)
txn.executemany(
"""
UPDATE remote_media_cache
SET quarantined_by = ?
WHERE media_origin = ? AND media_id = ?
""",
(
(quarantined_by, origin, media_id)
for origin, media_id in remote_mxcs
),
)
total_media_quarantined += len(local_mxcs)
total_media_quarantined += len(remote_mxcs)
return total_media_quarantined
return self.db.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
def _get_media_mxcs_in_room_txn(self, txn, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
txn (cursor)
room_id (str)
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
"""
mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
sql = """
SELECT stream_ordering, json FROM events
JOIN event_json USING (room_id, event_id)
WHERE room_id = ?
%(where_clause)s
AND contains_url = ? AND outlier = ?
ORDER BY stream_ordering DESC
LIMIT ?
"""
txn.execute(sql % {"where_clause": ""}, (room_id, True, False, 100))
local_media_mxcs = []
remote_media_mxcs = []
while True:
next_token = None
for stream_ordering, content_json in txn:
next_token = stream_ordering
event_json = json.loads(content_json)
content = event_json["content"]
content_url = content.get("url")
thumbnail_url = content.get("info", {}).get("thumbnail_url")
for url in (content_url, thumbnail_url):
if not url:
continue
matches = mxc_re.match(url)
if matches:
hostname = matches.group(1)
media_id = matches.group(2)
if hostname == self.hs.hostname:
local_media_mxcs.append(media_id)
else:
remote_media_mxcs.append((hostname, media_id))
if next_token is None:
# We've gone through the whole room, so we're finished.
break
txn.execute(
sql % {"where_clause": "AND stream_ordering < ?"},
(room_id, next_token, True, False, 100),
)
return local_media_mxcs, remote_media_mxcs
class RoomBackgroundUpdateStore(SQLBaseStore): class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@ -810,126 +938,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
(room_id,), (room_id,),
) )
def get_media_mxcs_in_room(self, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
room_id (str)
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
"""
def _get_media_mxcs_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
local_media_mxcs = []
remote_media_mxcs = []
# Convert the IDs to MXC URIs
for media_id in local_mxcs:
local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
for hostname, media_id in remote_mxcs:
remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
return local_media_mxcs, remote_media_mxcs
return self.db.runInteraction(
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
)
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines
the associated media
"""
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
total_media_quarantined = 0
# Now update all the tables to set the quarantined_by flag
txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
""",
((quarantined_by, media_id) for media_id in local_mxcs),
)
txn.executemany(
"""
UPDATE remote_media_cache
SET quarantined_by = ?
WHERE media_origin = ? AND media_id = ?
""",
(
(quarantined_by, origin, media_id)
for origin, media_id in remote_mxcs
),
)
total_media_quarantined += len(local_mxcs)
total_media_quarantined += len(remote_mxcs)
return total_media_quarantined
return self.db.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
def _get_media_mxcs_in_room_txn(self, txn, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
txn (cursor)
room_id (str)
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
"""
mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
next_token = self.get_current_events_token() + 1
local_media_mxcs = []
remote_media_mxcs = []
while next_token:
sql = """
SELECT stream_ordering, json FROM events
JOIN event_json USING (room_id, event_id)
WHERE room_id = ?
AND stream_ordering < ?
AND contains_url = ? AND outlier = ?
ORDER BY stream_ordering DESC
LIMIT ?
"""
txn.execute(sql, (room_id, next_token, True, False, 100))
next_token = None
for stream_ordering, content_json in txn:
next_token = stream_ordering
event_json = json.loads(content_json)
content = event_json["content"]
content_url = content.get("url")
thumbnail_url = content.get("info", {}).get("thumbnail_url")
for url in (content_url, thumbnail_url):
if not url:
continue
matches = mxc_re.match(url)
if matches:
hostname = matches.group(1)
media_id = matches.group(2)
if hostname == self.hs.hostname:
local_media_mxcs.append(media_id)
else:
remote_media_mxcs.append((hostname, media_id))
return local_media_mxcs, remote_media_mxcs
@defer.inlineCallbacks @defer.inlineCallbacks
def get_rooms_for_retention_period_in_range( def get_rooms_for_retention_period_in_range(
self, min_ms, max_ms, include_null=False self, min_ms, max_ms, include_null=False