Port storage/ to Python 3 (#3725)

This commit is contained in:
Amber Brown 2018-08-31 00:19:58 +10:00 committed by GitHub
parent 475253a88e
commit 14e4d4f4bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 208 additions and 36 deletions

1
changelog.d/3725.misc Normal file
View file

@ -0,0 +1 @@
The synapse.storage module has been ported to Python 3.

View file

@ -31,5 +31,5 @@ $TOX_BIN/pip install 'setuptools>=18.5'
$TOX_BIN/pip install 'pip>=10' $TOX_BIN/pip install 'pip>=10'
{ python synapse/python_dependencies.py { python synapse/python_dependencies.py
echo lxml psycopg2 echo lxml
} | xargs $TOX_BIN/pip install } | xargs $TOX_BIN/pip install

View file

@ -78,6 +78,9 @@ CONDITIONAL_REQUIREMENTS = {
"affinity": { "affinity": {
"affinity": ["affinity"], "affinity": ["affinity"],
}, },
"postgres": {
"psycopg2>=2.6": ["psycopg2"]
}
} }

View file

@ -17,9 +17,10 @@ import sys
import threading import threading
import time import time
from six import iteritems, iterkeys, itervalues from six import PY2, iteritems, iterkeys, itervalues
from six.moves import intern, range from six.moves import intern, range
from canonicaljson import json
from prometheus_client import Histogram from prometheus_client import Histogram
from twisted.internet import defer from twisted.internet import defer
@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception):
something went wrong. something went wrong.
""" """
pass pass
def db_to_json(db_content):
"""
Take some data from a database row and return a JSON-decoded object.
Args:
db_content (memoryview|buffer|bytes|bytearray|unicode)
"""
# psycopg2 on Python 3 returns memoryview objects, which we need to
# cast to bytes to decode
if isinstance(db_content, memoryview):
db_content = db_content.tobytes()
# psycopg2 on Python 2 returns buffer objects, which we need to cast to
# bytes to decode
if PY2 and isinstance(db_content, buffer):
db_content = bytes(db_content)
# Decode it to a Unicode string before feeding it to json.loads, so we
# consistenty get a Unicode-containing object out.
if isinstance(db_content, (bytes, bytearray)):
db_content = db_content.decode('utf8')
try:
return json.loads(db_content)
except Exception:
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
raise

View file

@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
local_by_user_then_device = {} local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items(): for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {} messages_json_for_user = {}
devices = messages_by_device.keys() devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*": if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids. # Handle wildcard device_ids.
sql = ( sql = (

View file

@ -24,7 +24,7 @@ from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from ._base import Cache, SQLBaseStore from ._base import Cache, SQLBaseStore, db_to_json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -411,7 +411,7 @@ class DeviceStore(SQLBaseStore):
if device is not None: if device is not None:
key_json = device.get("key_json", None) key_json = device.get("key_json", None)
if key_json: if key_json:
result["keys"] = json.loads(key_json) result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None) device_display_name = device.get("device_display_name", None)
if device_display_name: if device_display_name:
result["device_display_name"] = device_display_name result["device_display_name"] = device_display_name
@ -466,7 +466,7 @@ class DeviceStore(SQLBaseStore):
retcol="content", retcol="content",
desc="_get_cached_user_device", desc="_get_cached_user_device",
) )
defer.returnValue(json.loads(content)) defer.returnValue(db_to_json(content))
@cachedInlineCallbacks() @cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id): def _get_cached_devices_for_user(self, user_id):
@ -479,7 +479,7 @@ class DeviceStore(SQLBaseStore):
desc="_get_cached_devices_for_user", desc="_get_cached_devices_for_user",
) )
defer.returnValue({ defer.returnValue({
device["device_id"]: json.loads(device["content"]) device["device_id"]: db_to_json(device["content"])
for device in devices for device in devices
}) })
@ -511,7 +511,7 @@ class DeviceStore(SQLBaseStore):
key_json = device.get("key_json", None) key_json = device.get("key_json", None)
if key_json: if key_json:
result["keys"] = json.loads(key_json) result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None) device_display_name = device.get("device_display_name", None)
if device_display_name: if device_display_name:
result["device_display_name"] = device_display_name result["device_display_name"] = device_display_name

View file

@ -14,13 +14,13 @@
# limitations under the License. # limitations under the License.
from six import iteritems from six import iteritems
from canonicaljson import encode_canonical_json, json from canonicaljson import encode_canonical_json
from twisted.internet import defer from twisted.internet import defer
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore from ._base import SQLBaseStore, db_to_json
class EndToEndKeyStore(SQLBaseStore): class EndToEndKeyStore(SQLBaseStore):
@ -90,7 +90,7 @@ class EndToEndKeyStore(SQLBaseStore):
for user_id, device_keys in iteritems(results): for user_id, device_keys in iteritems(results):
for device_id, device_info in iteritems(device_keys): for device_id, device_info in iteritems(device_keys):
device_info["keys"] = json.loads(device_info.pop("key_json")) device_info["keys"] = db_to_json(device_info.pop("key_json"))
defer.returnValue(results) defer.returnValue(results)

View file

@ -41,13 +41,18 @@ class PostgresEngine(object):
db_conn.set_isolation_level( db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
) )
# Set the bytea output to escape, vs the default of hex
cursor = db_conn.cursor()
cursor.execute("SET bytea_output TO escape")
# Asynchronous commit, don't wait for the server to call fsync before # Asynchronous commit, don't wait for the server to call fsync before
# ending the transaction. # ending the transaction.
# https://www.postgresql.org/docs/current/static/wal-async-commit.html # https://www.postgresql.org/docs/current/static/wal-async-commit.html
if not self.synchronous_commit: if not self.synchronous_commit:
cursor = db_conn.cursor()
cursor.execute("SET synchronous_commit TO OFF") cursor.execute("SET synchronous_commit TO OFF")
cursor.close()
cursor.close()
def is_deadlock(self, error): def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError): if isinstance(error, self.module.DatabaseError):

View file

@ -19,7 +19,7 @@ import logging
from collections import OrderedDict, deque, namedtuple from collections import OrderedDict, deque, namedtuple
from functools import wraps from functools import wraps
from six import iteritems from six import iteritems, text_type
from six.moves import range from six.moves import range
from canonicaljson import json from canonicaljson import json
@ -1220,7 +1220,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
"sender": event.sender, "sender": event.sender,
"contains_url": ( "contains_url": (
"url" in event.content "url" in event.content
and isinstance(event.content["url"], basestring) and isinstance(event.content["url"], text_type)
), ),
} }
for event, _ in events_and_contexts for event, _ in events_and_contexts
@ -1529,7 +1529,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
contains_url = "url" in content contains_url = "url" in content
if contains_url: if contains_url:
contains_url &= isinstance(content["url"], basestring) contains_url &= isinstance(content["url"], text_type)
except (KeyError, AttributeError): except (KeyError, AttributeError):
# If the event is missing a necessary field then # If the event is missing a necessary field then
# skip over it. # skip over it.
@ -1910,9 +1910,9 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
(room_id,) (room_id,)
) )
rows = txn.fetchall() rows = txn.fetchall()
max_depth = max(row[0] for row in rows) max_depth = max(row[1] for row in rows)
if max_depth <= token.topological: if max_depth < token.topological:
# We need to ensure we don't delete all the events from the database # We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not # otherwise we wouldn't be able to send any events (due to not
# having any backwards extremeties) # having any backwards extremeties)

View file

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import itertools import itertools
import logging import logging
from collections import namedtuple from collections import namedtuple
@ -265,7 +266,7 @@ class EventsWorkerStore(SQLBaseStore):
""" """
with Measure(self._clock, "_fetch_event_list"): with Measure(self._clock, "_fetch_event_list"):
try: try:
event_id_lists = zip(*event_list)[0] event_id_lists = list(zip(*event_list))[0]
event_ids = [ event_ids = [
item for sublist in event_id_lists for item in sublist item for sublist in event_id_lists for item in sublist
] ]
@ -299,14 +300,14 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch") logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread # We only want to resolve deferreds from the main thread
def fire(evs): def fire(evs, exc):
for _, d in evs: for _, d in evs:
if not d.called: if not d.called:
with PreserveLoggingContext(): with PreserveLoggingContext():
d.errback(e) d.errback(exc)
with PreserveLoggingContext(): with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list) self.hs.get_reactor().callFromThread(fire, event_list, e)
@defer.inlineCallbacks @defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):

View file

@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from canonicaljson import encode_canonical_json, json from canonicaljson import encode_canonical_json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from ._base import SQLBaseStore from ._base import SQLBaseStore, db_to_json
class FilteringStore(SQLBaseStore): class FilteringStore(SQLBaseStore):
@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter", desc="get_user_filter",
) )
defer.returnValue(json.loads(bytes(def_json).decode("utf-8"))) defer.returnValue(db_to_json(def_json))
def add_user_filter(self, user_localpart, user_filter): def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter) def_json = encode_canonical_json(user_filter)

View file

@ -15,7 +15,8 @@
# limitations under the License. # limitations under the License.
import logging import logging
import types
import six
from canonicaljson import encode_canonical_json, json from canonicaljson import encode_canonical_json, json
@ -27,6 +28,11 @@ from ._base import SQLBaseStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if six.PY2:
db_binary_type = buffer
else:
db_binary_type = memoryview
class PusherWorkerStore(SQLBaseStore): class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows): def _decode_pushers_rows(self, rows):
@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore):
dataJson = r['data'] dataJson = r['data']
r['data'] = None r['data'] = None
try: try:
if isinstance(dataJson, types.BufferType): if isinstance(dataJson, db_binary_type):
dataJson = str(dataJson).decode("UTF8") dataJson = str(dataJson).decode("UTF8")
r['data'] = json.loads(dataJson) r['data'] = json.loads(dataJson)
except Exception as e: except Exception as e:
logger.warn( logger.warn(
"Invalid JSON in data for pusher %d: %s, %s", "Invalid JSON in data for pusher %d: %s, %s",
r['id'], dataJson, e.message, r['id'], dataJson, e.args[0],
) )
pass pass
if isinstance(r['pushkey'], types.BufferType): if isinstance(r['pushkey'], db_binary_type):
r['pushkey'] = str(r['pushkey']).decode("UTF8") r['pushkey'] = str(r['pushkey']).decode("UTF8")
return rows return rows

View file

@ -18,14 +18,14 @@ from collections import namedtuple
import six import six
from canonicaljson import encode_canonical_json, json from canonicaljson import encode_canonical_json
from twisted.internet import defer from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore from ._base import SQLBaseStore, db_to_json
# py2 sqlite has buffer hardcoded as only binary type, so we must use it, # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview # despite being deprecated and removed in favor of memoryview
@ -95,7 +95,8 @@ class TransactionStore(SQLBaseStore):
) )
if result and result["response_code"]: if result and result["response_code"]:
return result["response_code"], json.loads(str(result["response_json"])) return result["response_code"], db_to_json(result["response_json"])
else: else:
return None return None

View file

@ -240,7 +240,6 @@ class RestHelper(object):
self.assertEquals(200, code) self.assertEquals(200, code)
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
if txn_id is None: if txn_id is None:
txn_id = "m%s" % (str(time.time())) txn_id = "m%s" % (str(time.time()))
@ -248,9 +247,16 @@ class RestHelper(object):
body = "body_text_here" body = "body_text_here"
path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id) path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
content = '{"msgtype":"m.text","body":"%s"}' % body content = {"msgtype": "m.text", "body": body}
if tok: if tok:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
(code, response) = yield self.mock_resource.trigger("PUT", path, content) request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
self.assertEquals(expect_code, code, msg=str(response)) render(request, self.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
% (expect_code, int(channel.result["code"]), channel.result["body"])
)
return channel.json_body

106
tests/storage/test_purge.py Normal file
View file

@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
class PurgeTests(HomeserverTestCase):
user_id = "@red:server"
servlets = [room.register_servlets]
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver("server", http_client=None)
return hs
def prepare(self, reactor, clock, hs):
self.room_id = self.helper.create_room_as(self.user_id)
def test_purge(self):
"""
Purging a room will delete everything before the topological point.
"""
# Send four messages to the room
first = self.helper.send(self.room_id, body="test1")
second = self.helper.send(self.room_id, body="test2")
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
storage = self.hs.get_datastore()
# Get the topological token
event = storage.get_topological_token_for_event(last["event_id"])
self.pump()
event = self.successResultOf(event)
# Purge everything before this topological token
purge = storage.purge_history(self.room_id, event, True)
self.pump()
self.assertEqual(self.successResultOf(purge), None)
# Try and get the events
get_first = storage.get_event(first["event_id"])
get_second = storage.get_event(second["event_id"])
get_third = storage.get_event(third["event_id"])
get_last = storage.get_event(last["event_id"])
self.pump()
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
self.failureResultOf(get_first)
self.failureResultOf(get_second)
self.failureResultOf(get_third)
self.successResultOf(get_last)
def test_purge_wont_delete_extrems(self):
"""
Purging a room will delete everything before the topological point.
"""
# Send four messages to the room
first = self.helper.send(self.room_id, body="test1")
second = self.helper.send(self.room_id, body="test2")
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
storage = self.hs.get_datastore()
# Set the topological token higher than it should be
event = storage.get_topological_token_for_event(last["event_id"])
self.pump()
event = self.successResultOf(event)
event = "t{}-{}".format(
*list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
)
# Purge everything before this topological token
purge = storage.purge_history(self.room_id, event, True)
self.pump()
f = self.failureResultOf(purge)
self.assertIn("greater than forward", f.value.args[0])
# Try and get the events
get_first = storage.get_event(first["event_id"])
get_second = storage.get_event(second["event_id"])
get_third = storage.get_event(third["event_id"])
get_last = storage.get_event(last["event_id"])
self.pump()
# Nothing is deleted.
self.successResultOf(get_first)
self.successResultOf(get_second)
self.successResultOf(get_third)
self.successResultOf(get_last)

View file

@ -151,6 +151,7 @@ class HomeserverTestCase(TestCase):
hijack_auth (bool): Whether to hijack auth to return the user specified hijack_auth (bool): Whether to hijack auth to return the user specified
in user_id. in user_id.
""" """
servlets = [] servlets = []
hijack_auth = True hijack_auth = True
@ -279,3 +280,13 @@ class HomeserverTestCase(TestCase):
kwargs = dict(kwargs) kwargs = dict(kwargs)
kwargs.update(self._hs_args) kwargs.update(self._hs_args)
return setup_test_homeserver(self.addCleanup, *args, **kwargs) return setup_test_homeserver(self.addCleanup, *args, **kwargs)
def pump(self):
"""
Pump the reactor enough that Deferreds will fire.
"""
self.reactor.pump([0.0] * 100)
def get_success(self, d):
self.pump()
return self.successResultOf(d)

View file

@ -147,6 +147,8 @@ def setup_test_homeserver(
config.max_mau_value = 50 config.max_mau_value = 50
config.mau_limits_reserved_threepids = [] config.mau_limits_reserved_threepids = []
config.admin_contact = None config.admin_contact = None
config.rc_messages_per_second = 10000
config.rc_message_burst_count = 10000
# we need a sane default_room_version, otherwise attempts to create rooms will # we need a sane default_room_version, otherwise attempts to create rooms will
# fail. # fail.