Share tests with test_sendtodevice

This commit is contained in:
Eric Eastwood 2024-05-08 10:48:59 -05:00
parent 1e05a05f03
commit 5e925f621c
4 changed files with 34 additions and 76 deletions

View file

@ -18,9 +18,9 @@
# [This file includes modifications made by New Vector Limited]
#
#
from enum import Enum
import itertools
import logging
from enum import Enum
from typing import (
TYPE_CHECKING,
AbstractSet,

View file

@ -20,8 +20,8 @@
#
import itertools
import logging
from collections import defaultdict
import re
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState

View file

@ -19,9 +19,13 @@
#
#
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config
@ -34,6 +38,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/sync"
def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered"""
@ -54,7 +61,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
# check it appears
channel = self.make_request("GET", "/sync", access_token=user2_tok)
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
expected_result = {
"events": [
@ -70,14 +77,14 @@ class SendToDeviceTestCase(HomeserverTestCase):
# it should re-appear if we do another sync because the to-device message is not
# deleted until we acknowledge it by sending a `?since=...` parameter in the
# next sync request corresponding to the `next_batch` value from the response.
channel = self.make_request("GET", "/sync", access_token=user2_tok)
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should *not* appear if we do an incremental sync
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET", f"/sync?since={sync_token}", access_token=user2_tok
"GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
@ -102,7 +109,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
# now sync: we should get two of the three (because burst_count=2)
channel = self.make_request("GET", "/sync", access_token=user2_tok)
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
@ -127,7 +134,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
"GET", f"/sync?since={sync_token}", access_token=user2_tok
"GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
@ -161,7 +168,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
# now sync: we should get two of the three
channel = self.make_request("GET", "/sync", access_token=user2_tok)
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
@ -195,7 +202,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
"GET", f"/sync?since={sync_token}", access_token=user2_tok
"GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
@ -219,7 +226,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
user2_tok = self.login("u2", "pass", "d2")
# Do an initial sync
channel = self.make_request("GET", "/sync", access_token=user2_tok)
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
sync_token = channel.json_body["next_batch"]
@ -235,7 +242,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request(
"GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
"GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
@ -243,7 +252,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
"GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])

View file

@ -1,74 +1,21 @@
from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config
from tests.rest.client.test_sendtodevice import SendToDeviceTestCase
class SendToDeviceTestCase(HomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
sendtodevice.register_servlets,
sync.register_servlets,
]
class SlidingSyncSendToDeviceTestCase(SendToDeviceTestCase):
def default_config(self) -> JsonDict:
config = super().default_config()
# Enable sliding sync
config["experimental_features"] = {"msc3575_enabled": True}
return config
def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered"""
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Use the Sliding Sync `/sync/e2ee` endpoint
self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee"
user1 = self.register_user("u1", "pass")
user1_tok = self.login("u1", "pass", "d1")
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
# send the message
test_msg = {"foo": "bar"}
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.test/1234",
content={"messages": {user2: {"d2": test_msg}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
# check it appears
channel = self.make_request("GET", "/sync", access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
expected_result = {
"events": [
{
"sender": user1,
"type": "m.test",
"content": test_msg,
}
]
}
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should re-appear if we do another sync because the to-device message is not
# deleted until we acknowledge it by sending a `?since=...` parameter in the
# next sync request corresponding to the `next_batch` value from the response.
channel = self.make_request(
"GET",
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should *not* appear if we do an incremental sync
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
# See SendToDeviceTestCase