Share tests with test_sendtodevice

This commit is contained in:
Eric Eastwood 2024-05-08 10:48:59 -05:00
parent 1e05a05f03
commit 5e925f621c
4 changed files with 34 additions and 76 deletions

View file

@ -18,9 +18,9 @@
# [This file includes modifications made by New Vector Limited] # [This file includes modifications made by New Vector Limited]
# #
# #
from enum import Enum
import itertools import itertools
import logging import logging
from enum import Enum
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
AbstractSet, AbstractSet,

View file

@ -20,8 +20,8 @@
# #
import itertools import itertools
import logging import logging
from collections import defaultdict
import re import re
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState

View file

@ -19,9 +19,13 @@
# #
# #
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes from synapse.api.constants import EduTypes
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync from synapse.rest.client import login, sendtodevice, sync
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
@ -34,6 +38,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync.register_servlets, sync.register_servlets,
] ]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/sync"
def test_user_to_user(self) -> None: def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered""" """A to-device message from one user to another should get delivered"""
@ -54,7 +61,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.code, 200, chan.result)
# check it appears # check it appears
channel = self.make_request("GET", "/sync", access_token=user2_tok) channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
expected_result = { expected_result = {
"events": [ "events": [
@ -70,14 +77,14 @@ class SendToDeviceTestCase(HomeserverTestCase):
# it should re-appear if we do another sync because the to-device message is not # it should re-appear if we do another sync because the to-device message is not
# deleted until we acknowledge it by sending a `?since=...` parameter in the # deleted until we acknowledge it by sending a `?since=...` parameter in the
# next sync request corresponding to the `next_batch` value from the response. # next sync request corresponding to the `next_batch` value from the response.
channel = self.make_request("GET", "/sync", access_token=user2_tok) channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result) self.assertEqual(channel.json_body["to_device"], expected_result)
# it should *not* appear if we do an incremental sync # it should *not* appear if we do an incremental sync
sync_token = channel.json_body["next_batch"] sync_token = channel.json_body["next_batch"]
channel = self.make_request( channel = self.make_request(
"GET", f"/sync?since={sync_token}", access_token=user2_tok "GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok
) )
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), []) self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
@ -102,7 +109,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.code, 200, chan.result)
# now sync: we should get two of the three (because burst_count=2) # now sync: we should get two of the three (because burst_count=2)
channel = self.make_request("GET", "/sync", access_token=user2_tok) channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"] msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2) self.assertEqual(len(msgs), 2)
@ -127,7 +134,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive # ... which should arrive
channel = self.make_request( channel = self.make_request(
"GET", f"/sync?since={sync_token}", access_token=user2_tok "GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok
) )
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"] msgs = channel.json_body["to_device"]["events"]
@ -161,7 +168,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
) )
# now sync: we should get two of the three # now sync: we should get two of the three
channel = self.make_request("GET", "/sync", access_token=user2_tok) channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"] msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2) self.assertEqual(len(msgs), 2)
@ -195,7 +202,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive # ... which should arrive
channel = self.make_request( channel = self.make_request(
"GET", f"/sync?since={sync_token}", access_token=user2_tok "GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok
) )
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"] msgs = channel.json_body["to_device"]["events"]
@ -219,7 +226,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
user2_tok = self.login("u2", "pass", "d2") user2_tok = self.login("u2", "pass", "d2")
# Do an initial sync # Do an initial sync
channel = self.make_request("GET", "/sync", access_token=user2_tok) channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
sync_token = channel.json_body["next_batch"] sync_token = channel.json_body["next_batch"]
@ -235,7 +242,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request( channel = self.make_request(
"GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok "GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
access_token=user2_tok,
) )
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", []) messages = channel.json_body.get("to_device", {}).get("events", [])
@ -243,7 +252,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync_token = channel.json_body["next_batch"] sync_token = channel.json_body["next_batch"]
channel = self.make_request( channel = self.make_request(
"GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok "GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
access_token=user2_tok,
) )
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", []) messages = channel.json_body.get("to_device", {}).get("events", [])

View file

@ -1,74 +1,21 @@
from synapse.api.constants import EduTypes from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config from tests.rest.client.test_sendtodevice import SendToDeviceTestCase
class SendToDeviceTestCase(HomeserverTestCase): class SlidingSyncSendToDeviceTestCase(SendToDeviceTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
sendtodevice.register_servlets,
sync.register_servlets,
]
def default_config(self) -> JsonDict: def default_config(self) -> JsonDict:
config = super().default_config() config = super().default_config()
# Enable sliding sync
config["experimental_features"] = {"msc3575_enabled": True} config["experimental_features"] = {"msc3575_enabled": True}
return config return config
def test_user_to_user(self) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
"""A to-device message from one user to another should get delivered""" # Use the Sliding Sync `/sync/e2ee` endpoint
self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee"
user1 = self.register_user("u1", "pass") # See SendToDeviceTestCase
user1_tok = self.login("u1", "pass", "d1")
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
# send the message
test_msg = {"foo": "bar"}
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.test/1234",
content={"messages": {user2: {"d2": test_msg}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
# check it appears
channel = self.make_request("GET", "/sync", access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
expected_result = {
"events": [
{
"sender": user1,
"type": "m.test",
"content": test_msg,
}
]
}
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should re-appear if we do another sync because the to-device message is not
# deleted until we acknowledge it by sending a `?since=...` parameter in the
# next sync request corresponding to the `next_batch` value from the response.
channel = self.make_request(
"GET",
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should *not* appear if we do an incremental sync
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])