diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py index 01b1f13d82..a9a28797b1 100644 --- a/tests/rest/client/test_sendtodevice.py +++ b/tests/rest/client/test_sendtodevice.py @@ -1,261 +1,13 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2021 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# . -# -# Originally licensed under the Apache License, Version 2.0: -# . -# -# [This file includes modifications made by New Vector Limited] -# -# - from twisted.test.proto_helpers import MemoryReactor -from synapse.api.constants import EduTypes -from synapse.rest import admin -from synapse.rest.client import login, sendtodevice, sync from synapse.server import HomeServer +from synapse.types import JsonDict from synapse.util import Clock -from tests.unittest import HomeserverTestCase, override_config +from tests.rest.client.test_sendtodevice_base import SendToDeviceTestCaseBase +from tests.unittest import HomeserverTestCase -class SendToDeviceTestCase(HomeserverTestCase): - servlets = [ - admin.register_servlets, - login.register_servlets, - sendtodevice.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sync_endpoint = "/sync" - - def test_user_to_user(self) -> None: - """A to-device message from one user to another should get delivered""" - - user1 = self.register_user("u1", "pass") - user1_tok = self.login("u1", "pass", "d1") - - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - # send the message - test_msg = {"foo": "bar"} - chan = self.make_request( - "PUT", - "/_matrix/client/r0/sendToDevice/m.test/1234", - content={"messages": {user2: {"d2": test_msg}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - # check it appears - channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - expected_result = { - "events": [ - { - "sender": user1, - "type": "m.test", - "content": test_msg, - } - ] - } - self.assertEqual(channel.json_body["to_device"], expected_result) - - # it should re-appear if we do another sync because the to-device message is not - # deleted until we acknowledge it by sending a `?since=...` parameter in the - # next sync request corresponding to the `next_batch` value from the response. - channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - self.assertEqual(channel.json_body["to_device"], expected_result) - - # it should *not* appear if we do an incremental sync - sync_token = channel.json_body["next_batch"] - channel = self.make_request( - "GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), []) - - @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_local_room_key_request(self) -> None: - """m.room_key_request has special-casing; test from local user""" - user1 = self.register_user("u1", "pass") - user1_tok = self.login("u1", "pass", "d1") - - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - # send three messages - for i in range(3): - chan = self.make_request( - "PUT", - f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}", - content={"messages": {user2: {"d2": {"idx": i}}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - # now sync: we should get two of the three (because burst_count=2) - channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 2) - for i in range(2): - self.assertEqual( - msgs[i], - {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}}, - ) - sync_token = channel.json_body["next_batch"] - - # ... time passes - self.reactor.advance(1) - - # and we can send more messages - chan = self.make_request( - "PUT", - "/_matrix/client/r0/sendToDevice/m.room_key_request/3", - content={"messages": {user2: {"d2": {"idx": 3}}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - # ... which should arrive - channel = self.make_request( - "GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 1) - self.assertEqual( - msgs[0], - {"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}}, - ) - - @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_remote_room_key_request(self) -> None: - """m.room_key_request has special-casing; test from remote user""" - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - federation_registry = self.hs.get_federation_registry() - - # send three messages - for i in range(3): - self.get_success( - federation_registry.on_edu( - EduTypes.DIRECT_TO_DEVICE, - "remote_server", - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "messages": {user2: {"d2": {"idx": i}}}, - "message_id": f"{i}", - }, - ) - ) - - # now sync: we should get two of the three - channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 2) - for i in range(2): - self.assertEqual( - msgs[i], - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "content": {"idx": i}, - }, - ) - sync_token = channel.json_body["next_batch"] - - # ... time passes - self.reactor.advance(1) - - # and we can send more messages - self.get_success( - federation_registry.on_edu( - EduTypes.DIRECT_TO_DEVICE, - "remote_server", - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "messages": {user2: {"d2": {"idx": 3}}}, - "message_id": "3", - }, - ) - ) - - # ... which should arrive - channel = self.make_request( - "GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 1) - self.assertEqual( - msgs[0], - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "content": {"idx": 3}, - }, - ) - - def test_limited_sync(self) -> None: - """If a limited sync for to-devices happens the next /sync should respond immediately.""" - - self.register_user("u1", "pass") - user1_tok = self.login("u1", "pass", "d1") - - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - # Do an initial sync - channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - sync_token = channel.json_body["next_batch"] - - # Send 150 to-device messages. We limit to 100 in `/sync` - for i in range(150): - test_msg = {"foo": "bar"} - chan = self.make_request( - "PUT", - f"/_matrix/client/r0/sendToDevice/m.test/1234-{i}", - content={"messages": {user2: {"d2": test_msg}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={sync_token}&timeout=300000", - access_token=user2_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - messages = channel.json_body.get("to_device", {}).get("events", []) - self.assertEqual(len(messages), 100) - sync_token = channel.json_body["next_batch"] - - channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={sync_token}&timeout=300000", - access_token=user2_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - messages = channel.json_body.get("to_device", {}).get("events", []) - self.assertEqual(len(messages), 50) +class SendToDeviceTestCase(SendToDeviceTestCaseBase, HomeserverTestCase): + # See SendToDeviceTestCaseBase for tests + pass diff --git a/tests/rest/client/test_sendtodevice_base.py b/tests/rest/client/test_sendtodevice_base.py new file mode 100644 index 0000000000..62c6364a59 --- /dev/null +++ b/tests/rest/client/test_sendtodevice_base.py @@ -0,0 +1,268 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright 2021 The Matrix.org Foundation C.I.C. +# Copyright (C) 2023 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +# Originally licensed under the Apache License, Version 2.0: +# . +# +# [This file includes modifications made by New Vector Limited] +# +# + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EduTypes +from synapse.rest import admin +from synapse.rest.client import login, sendtodevice, sync +from synapse.server import HomeServer +from synapse.util import Clock + +from tests.unittest import override_config + + +class SendToDeviceTestCaseBase: + """ + Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`. + + In order to run the tests, inherit from this base-class with `HomeserverTestCase`, e.g. + `class SendToDeviceTestCase(SendToDeviceTestCase, HomeserverTestCase)` + """ + + servlets = [ + admin.register_servlets, + login.register_servlets, + sendtodevice.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.sync_endpoint = "/sync" + + def test_user_to_user(self) -> None: + """A to-device message from one user to another should get delivered""" + + user1 = self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # send the message + test_msg = {"foo": "bar"} + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.test/1234", + content={"messages": {user2: {"d2": test_msg}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + # check it appears + channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + expected_result = { + "events": [ + { + "sender": user1, + "type": "m.test", + "content": test_msg, + } + ] + } + self.assertEqual(channel.json_body["to_device"], expected_result) + + # it should re-appear if we do another sync because the to-device message is not + # deleted until we acknowledge it by sending a `?since=...` parameter in the + # next sync request corresponding to the `next_batch` value from the response. + channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["to_device"], expected_result) + + # it should *not* appear if we do an incremental sync + sync_token = channel.json_body["next_batch"] + channel = self.make_request( + "GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), []) + + @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) + def test_local_room_key_request(self) -> None: + """m.room_key_request has special-casing; test from local user""" + user1 = self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # send three messages + for i in range(3): + chan = self.make_request( + "PUT", + f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}", + content={"messages": {user2: {"d2": {"idx": i}}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + # now sync: we should get two of the three (because burst_count=2) + channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 2) + for i in range(2): + self.assertEqual( + msgs[i], + {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}}, + ) + sync_token = channel.json_body["next_batch"] + + # ... time passes + self.reactor.advance(1) + + # and we can send more messages + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/3", + content={"messages": {user2: {"d2": {"idx": 3}}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + # ... which should arrive + channel = self.make_request( + "GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 1) + self.assertEqual( + msgs[0], + {"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}}, + ) + + @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) + def test_remote_room_key_request(self) -> None: + """m.room_key_request has special-casing; test from remote user""" + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + federation_registry = self.hs.get_federation_registry() + + # send three messages + for i in range(3): + self.get_success( + federation_registry.on_edu( + EduTypes.DIRECT_TO_DEVICE, + "remote_server", + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "messages": {user2: {"d2": {"idx": i}}}, + "message_id": f"{i}", + }, + ) + ) + + # now sync: we should get two of the three + channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 2) + for i in range(2): + self.assertEqual( + msgs[i], + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "content": {"idx": i}, + }, + ) + sync_token = channel.json_body["next_batch"] + + # ... time passes + self.reactor.advance(1) + + # and we can send more messages + self.get_success( + federation_registry.on_edu( + EduTypes.DIRECT_TO_DEVICE, + "remote_server", + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "messages": {user2: {"d2": {"idx": 3}}}, + "message_id": "3", + }, + ) + ) + + # ... which should arrive + channel = self.make_request( + "GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 1) + self.assertEqual( + msgs[0], + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "content": {"idx": 3}, + }, + ) + + def test_limited_sync(self) -> None: + """If a limited sync for to-devices happens the next /sync should respond immediately.""" + + self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # Do an initial sync + channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + sync_token = channel.json_body["next_batch"] + + # Send 150 to-device messages. We limit to 100 in `/sync` + for i in range(150): + test_msg = {"foo": "bar"} + chan = self.make_request( + "PUT", + f"/_matrix/client/r0/sendToDevice/m.test/1234-{i}", + content={"messages": {user2: {"d2": test_msg}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={sync_token}&timeout=300000", + access_token=user2_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + messages = channel.json_body.get("to_device", {}).get("events", []) + self.assertEqual(len(messages), 100) + sync_token = channel.json_body["next_batch"] + + channel = self.make_request( + "GET", + f"{self.sync_endpoint}?since={sync_token}&timeout=300000", + access_token=user2_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + messages = channel.json_body.get("to_device", {}).get("events", []) + self.assertEqual(len(messages), 50) diff --git a/tests/rest/client/test_sliding_sync.py b/tests/rest/client/test_sliding_sync.py index 46dde59e86..70a4375637 100644 --- a/tests/rest/client/test_sliding_sync.py +++ b/tests/rest/client/test_sliding_sync.py @@ -4,11 +4,12 @@ from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock -from tests.rest.client.test_sendtodevice import SendToDeviceTestCase +from tests.rest.client.test_sendtodevice_base import SendToDeviceTestCaseBase +from tests.unittest import HomeserverTestCase # Test To-Device messages working correctly with the `/sync/e2ee` endpoint (`to_device`) -class SlidingSyncE2eeSendToDeviceTestCase(SendToDeviceTestCase): +class SlidingSyncE2eeSendToDeviceTestCase(SendToDeviceTestCaseBase, HomeserverTestCase): def default_config(self) -> JsonDict: config = super().default_config() # Enable sliding sync @@ -19,4 +20,4 @@ class SlidingSyncE2eeSendToDeviceTestCase(SendToDeviceTestCase): # Use the Sliding Sync `/sync/e2ee` endpoint self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee" - # See SendToDeviceTestCase for tests + # See SendToDeviceTestCaseBase for tests