diff --git a/changelog.d/3725.misc b/changelog.d/3725.misc new file mode 100644 index 0000000000..91ab9d7137 --- /dev/null +++ b/changelog.d/3725.misc @@ -0,0 +1 @@ +The synapse.storage module has been ported to Python 3. diff --git a/jenkins/prepare_synapse.sh b/jenkins/prepare_synapse.sh index a30179f2aa..d95ca846c4 100755 --- a/jenkins/prepare_synapse.sh +++ b/jenkins/prepare_synapse.sh @@ -31,5 +31,5 @@ $TOX_BIN/pip install 'setuptools>=18.5' $TOX_BIN/pip install 'pip>=10' { python synapse/python_dependencies.py - echo lxml psycopg2 + echo lxml } | xargs $TOX_BIN/pip install diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 9c55e79ef5..942d7c721f 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -78,6 +78,9 @@ CONDITIONAL_REQUIREMENTS = { "affinity": { "affinity": ["affinity"], }, + "postgres": { + "psycopg2>=2.6": ["psycopg2"] + } } diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 08dffd774f..be61147b9b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -17,9 +17,10 @@ import sys import threading import time -from six import iteritems, iterkeys, itervalues +from six import PY2, iteritems, iterkeys, itervalues from six.moves import intern, range +from canonicaljson import json from prometheus_client import Histogram from twisted.internet import defer @@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception): something went wrong. """ pass + + +def db_to_json(db_content): + """ + Take some data from a database row and return a JSON-decoded object. + + Args: + db_content (memoryview|buffer|bytes|bytearray|unicode) + """ + # psycopg2 on Python 3 returns memoryview objects, which we need to + # cast to bytes to decode + if isinstance(db_content, memoryview): + db_content = db_content.tobytes() + + # psycopg2 on Python 2 returns buffer objects, which we need to cast to + # bytes to decode + if PY2 and isinstance(db_content, buffer): + db_content = bytes(db_content) + + # Decode it to a Unicode string before feeding it to json.loads, so we + # consistenty get a Unicode-containing object out. + if isinstance(db_content, (bytes, bytearray)): + db_content = db_content.decode('utf8') + + try: + return json.loads(db_content) + except Exception: + logging.warning("Tried to decode '%r' as JSON and failed", db_content) + raise diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 73646da025..e06b0bc56d 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore): local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): messages_json_for_user = {} - devices = messages_by_device.keys() + devices = list(messages_by_device.keys()) if len(devices) == 1 and devices[0] == "*": # Handle wildcard device_ids. sql = ( diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index c0943ecf91..d10ff9e4b9 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -24,7 +24,7 @@ from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from ._base import Cache, SQLBaseStore +from ._base import Cache, SQLBaseStore, db_to_json logger = logging.getLogger(__name__) @@ -411,7 +411,7 @@ class DeviceStore(SQLBaseStore): if device is not None: key_json = device.get("key_json", None) if key_json: - result["keys"] = json.loads(key_json) + result["keys"] = db_to_json(key_json) device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name @@ -466,7 +466,7 @@ class DeviceStore(SQLBaseStore): retcol="content", desc="_get_cached_user_device", ) - defer.returnValue(json.loads(content)) + defer.returnValue(db_to_json(content)) @cachedInlineCallbacks() def _get_cached_devices_for_user(self, user_id): @@ -479,7 +479,7 @@ class DeviceStore(SQLBaseStore): desc="_get_cached_devices_for_user", ) defer.returnValue({ - device["device_id"]: json.loads(device["content"]) + device["device_id"]: db_to_json(device["content"]) for device in devices }) @@ -511,7 +511,7 @@ class DeviceStore(SQLBaseStore): key_json = device.get("key_json", None) if key_json: - result["keys"] = json.loads(key_json) + result["keys"] = db_to_json(key_json) device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 523b4360c3..1f1721e820 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -14,13 +14,13 @@ # limitations under the License. from six import iteritems -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.util.caches.descriptors import cached -from ._base import SQLBaseStore +from ._base import SQLBaseStore, db_to_json class EndToEndKeyStore(SQLBaseStore): @@ -90,7 +90,7 @@ class EndToEndKeyStore(SQLBaseStore): for user_id, device_keys in iteritems(results): for device_id, device_info in iteritems(device_keys): - device_info["keys"] = json.loads(device_info.pop("key_json")) + device_info["keys"] = db_to_json(device_info.pop("key_json")) defer.returnValue(results) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 8a0386c1a4..42225f8a2a 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -41,13 +41,18 @@ class PostgresEngine(object): db_conn.set_isolation_level( self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) + + # Set the bytea output to escape, vs the default of hex + cursor = db_conn.cursor() + cursor.execute("SET bytea_output TO escape") + # Asynchronous commit, don't wait for the server to call fsync before # ending the transaction. # https://www.postgresql.org/docs/current/static/wal-async-commit.html if not self.synchronous_commit: - cursor = db_conn.cursor() cursor.execute("SET synchronous_commit TO OFF") - cursor.close() + + cursor.close() def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index f39c8c8461..8bf87f38f7 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -19,7 +19,7 @@ import logging from collections import OrderedDict, deque, namedtuple from functools import wraps -from six import iteritems +from six import iteritems, text_type from six.moves import range from canonicaljson import json @@ -1220,7 +1220,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore "sender": event.sender, "contains_url": ( "url" in event.content - and isinstance(event.content["url"], basestring) + and isinstance(event.content["url"], text_type) ), } for event, _ in events_and_contexts @@ -1529,7 +1529,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore contains_url = "url" in content if contains_url: - contains_url &= isinstance(content["url"], basestring) + contains_url &= isinstance(content["url"], text_type) except (KeyError, AttributeError): # If the event is missing a necessary field then # skip over it. @@ -1910,9 +1910,9 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore (room_id,) ) rows = txn.fetchall() - max_depth = max(row[0] for row in rows) + max_depth = max(row[1] for row in rows) - if max_depth <= token.topological: + if max_depth < token.topological: # We need to ensure we don't delete all the events from the database # otherwise we wouldn't be able to send any events (due to not # having any backwards extremeties) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 59822178ff..a8326f5296 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import itertools import logging from collections import namedtuple @@ -265,7 +266,7 @@ class EventsWorkerStore(SQLBaseStore): """ with Measure(self._clock, "_fetch_event_list"): try: - event_id_lists = zip(*event_list)[0] + event_id_lists = list(zip(*event_list))[0] event_ids = [ item for sublist in event_id_lists for item in sublist ] @@ -299,14 +300,14 @@ class EventsWorkerStore(SQLBaseStore): logger.exception("do_fetch") # We only want to resolve deferreds from the main thread - def fire(evs): + def fire(evs, exc): for _, d in evs: if not d.called: with PreserveLoggingContext(): - d.errback(e) + d.errback(exc) with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire, event_list) + self.hs.get_reactor().callFromThread(fire, event_list, e) @defer.inlineCallbacks def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index 2d5896c5b4..6ddcc909bf 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.api.errors import Codes, SynapseError from synapse.util.caches.descriptors import cachedInlineCallbacks -from ._base import SQLBaseStore +from ._base import SQLBaseStore, db_to_json class FilteringStore(SQLBaseStore): @@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore): desc="get_user_filter", ) - defer.returnValue(json.loads(bytes(def_json).decode("utf-8"))) + defer.returnValue(db_to_json(def_json)) def add_user_filter(self, user_localpart, user_filter): def_json = encode_canonical_json(user_filter) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 8443bd4c1b..c7987bfcdd 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -15,7 +15,8 @@ # limitations under the License. import logging -import types + +import six from canonicaljson import encode_canonical_json, json @@ -27,6 +28,11 @@ from ._base import SQLBaseStore logger = logging.getLogger(__name__) +if six.PY2: + db_binary_type = buffer +else: + db_binary_type = memoryview + class PusherWorkerStore(SQLBaseStore): def _decode_pushers_rows(self, rows): @@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore): dataJson = r['data'] r['data'] = None try: - if isinstance(dataJson, types.BufferType): + if isinstance(dataJson, db_binary_type): dataJson = str(dataJson).decode("UTF8") r['data'] = json.loads(dataJson) except Exception as e: logger.warn( "Invalid JSON in data for pusher %d: %s, %s", - r['id'], dataJson, e.message, + r['id'], dataJson, e.args[0], ) pass - if isinstance(r['pushkey'], types.BufferType): + if isinstance(r['pushkey'], db_binary_type): r['pushkey'] = str(r['pushkey']).decode("UTF8") return rows diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 428e7fa36e..0c42bd3322 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -18,14 +18,14 @@ from collections import namedtuple import six -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches.descriptors import cached -from ._base import SQLBaseStore +from ._base import SQLBaseStore, db_to_json # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview @@ -95,7 +95,8 @@ class TransactionStore(SQLBaseStore): ) if result and result["response_code"]: - return result["response_code"], json.loads(str(result["response_json"])) + return result["response_code"], db_to_json(result["response_json"]) + else: return None diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 40dc4ea256..530dc8ba6d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -240,7 +240,6 @@ class RestHelper(object): self.assertEquals(200, code) defer.returnValue(response) - @defer.inlineCallbacks def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -248,9 +247,16 @@ class RestHelper(object): body = "body_text_here" path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id) - content = '{"msgtype":"m.text","body":"%s"}' % body + content = {"msgtype": "m.text", "body": body} if tok: path = path + "?access_token=%s" % tok - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(expect_code, code, msg=str(response)) + request, channel = make_request("PUT", path, json.dumps(content).encode('utf8')) + render(request, self.resource, self.hs.get_reactor()) + + assert int(channel.result["code"]) == expect_code, ( + "Expected: %d, got: %d, resp: %r" + % (expect_code, int(channel.result["code"]), channel.result["body"]) + ) + + return channel.json_body diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py new file mode 100644 index 0000000000..f671599cb8 --- /dev/null +++ b/tests/storage/test_purge.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest.client.v1 import room + +from tests.unittest import HomeserverTestCase + + +class PurgeTests(HomeserverTestCase): + + user_id = "@red:server" + servlets = [room.register_servlets] + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver("server", http_client=None) + return hs + + def prepare(self, reactor, clock, hs): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_purge(self): + """ + Purging a room will delete everything before the topological point. + """ + # Send four messages to the room + first = self.helper.send(self.room_id, body="test1") + second = self.helper.send(self.room_id, body="test2") + third = self.helper.send(self.room_id, body="test3") + last = self.helper.send(self.room_id, body="test4") + + storage = self.hs.get_datastore() + + # Get the topological token + event = storage.get_topological_token_for_event(last["event_id"]) + self.pump() + event = self.successResultOf(event) + + # Purge everything before this topological token + purge = storage.purge_history(self.room_id, event, True) + self.pump() + self.assertEqual(self.successResultOf(purge), None) + + # Try and get the events + get_first = storage.get_event(first["event_id"]) + get_second = storage.get_event(second["event_id"]) + get_third = storage.get_event(third["event_id"]) + get_last = storage.get_event(last["event_id"]) + self.pump() + + # 1-3 should fail and last will succeed, meaning that 1-3 are deleted + # and last is not. + self.failureResultOf(get_first) + self.failureResultOf(get_second) + self.failureResultOf(get_third) + self.successResultOf(get_last) + + def test_purge_wont_delete_extrems(self): + """ + Purging a room will delete everything before the topological point. + """ + # Send four messages to the room + first = self.helper.send(self.room_id, body="test1") + second = self.helper.send(self.room_id, body="test2") + third = self.helper.send(self.room_id, body="test3") + last = self.helper.send(self.room_id, body="test4") + + storage = self.hs.get_datastore() + + # Set the topological token higher than it should be + event = storage.get_topological_token_for_event(last["event_id"]) + self.pump() + event = self.successResultOf(event) + event = "t{}-{}".format( + *list(map(lambda x: x + 1, map(int, event[1:].split("-")))) + ) + + # Purge everything before this topological token + purge = storage.purge_history(self.room_id, event, True) + self.pump() + f = self.failureResultOf(purge) + self.assertIn("greater than forward", f.value.args[0]) + + # Try and get the events + get_first = storage.get_event(first["event_id"]) + get_second = storage.get_event(second["event_id"]) + get_third = storage.get_event(third["event_id"]) + get_last = storage.get_event(last["event_id"]) + self.pump() + + # Nothing is deleted. + self.successResultOf(get_first) + self.successResultOf(get_second) + self.successResultOf(get_third) + self.successResultOf(get_last) diff --git a/tests/unittest.py b/tests/unittest.py index d852e2465a..8b513bb32b 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -151,6 +151,7 @@ class HomeserverTestCase(TestCase): hijack_auth (bool): Whether to hijack auth to return the user specified in user_id. """ + servlets = [] hijack_auth = True @@ -279,3 +280,13 @@ class HomeserverTestCase(TestCase): kwargs = dict(kwargs) kwargs.update(self._hs_args) return setup_test_homeserver(self.addCleanup, *args, **kwargs) + + def pump(self): + """ + Pump the reactor enough that Deferreds will fire. + """ + self.reactor.pump([0.0] * 100) + + def get_success(self, d): + self.pump() + return self.successResultOf(d) diff --git a/tests/utils.py b/tests/utils.py index 657c258a6e..179b592501 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -147,6 +147,8 @@ def setup_test_homeserver( config.max_mau_value = 50 config.mau_limits_reserved_threepids = [] config.admin_contact = None + config.rc_messages_per_second = 10000 + config.rc_message_burst_count = 10000 # we need a sane default_room_version, otherwise attempts to create rooms will # fail.