mirror of
https://github.com/element-hq/synapse
synced 2024-07-15 14:04:07 +00:00
De-localpart {Filtering,FilteringWorkerStore}.get_user_filter()
Signed-off-by: Sean Quah <seanq@matrix.org>
This commit is contained in:
parent
06f9ababc4
commit
f98141ceb2
|
@ -165,9 +165,9 @@ class Filtering:
|
|||
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
|
||||
|
||||
async def get_user_filter(
|
||||
self, user_localpart: str, filter_id: Union[int, str]
|
||||
self, user_id: str, filter_id: Union[int, str]
|
||||
) -> "FilterCollection":
|
||||
result = await self.store.get_user_filter(user_localpart, filter_id)
|
||||
result = await self.store.get_user_filter(user_id, filter_id)
|
||||
return FilterCollection(self._hs, result)
|
||||
|
||||
def add_user_filter(
|
||||
|
|
|
@ -58,7 +58,7 @@ class GetFilterRestServlet(RestServlet):
|
|||
|
||||
try:
|
||||
filter_collection = await self.filtering.get_user_filter(
|
||||
user_localpart=target_user.localpart, filter_id=filter_id_int
|
||||
user_id=user_id, filter_id=filter_id_int
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code != 404:
|
||||
|
|
|
@ -178,7 +178,7 @@ class SyncRestServlet(RestServlet):
|
|||
else:
|
||||
try:
|
||||
filter_collection = await self.filtering.get_user_filter(
|
||||
user.localpart, filter_id
|
||||
user.to_string(), filter_id
|
||||
)
|
||||
except StoreError as err:
|
||||
if err.code != 404:
|
||||
|
|
|
@ -24,7 +24,7 @@ from synapse.storage.database import (
|
|||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.types import JsonDict
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -34,8 +34,9 @@ if TYPE_CHECKING:
|
|||
class FilteringWorkerStore(SQLBaseStore):
|
||||
@cached(num_args=2)
|
||||
async def get_user_filter(
|
||||
self, user_localpart: str, filter_id: Union[int, str]
|
||||
self, user_id: str, filter_id: Union[int, str]
|
||||
) -> JsonDict:
|
||||
user_localpart = UserID.from_string(user_id).localpart
|
||||
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
|
||||
# with a coherent error message rather than 500 M_UNKNOWN.
|
||||
try:
|
||||
|
@ -43,13 +44,27 @@ class FilteringWorkerStore(SQLBaseStore):
|
|||
except ValueError:
|
||||
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
|
||||
|
||||
def_json = await self.db_pool.simple_select_one_onecol(
|
||||
table="user_filters",
|
||||
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
|
||||
retcol="filter_json",
|
||||
allow_none=False,
|
||||
desc="get_user_filter",
|
||||
)
|
||||
user_localpart = UserID.from_string(user_id).localpart
|
||||
try:
|
||||
def_json = await self.db_pool.simple_select_one_onecol(
|
||||
table="user_filters",
|
||||
keyvalues={"full_user_id": user_id, "filter_id": filter_id},
|
||||
retcol="filter_json",
|
||||
allow_none=False,
|
||||
desc="get_user_filter",
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
# Fall back to the `user_id` column.
|
||||
def_json = await self.db_pool.simple_select_one_onecol(
|
||||
table="user_filters",
|
||||
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
|
||||
retcol="filter_json",
|
||||
allow_none=False,
|
||||
desc="get_user_filter",
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
return db_to_json(def_json)
|
||||
|
||||
|
|
|
@ -33,7 +33,9 @@ from synapse.util.frozenutils import freeze
|
|||
from tests import unittest
|
||||
from tests.events.test_utils import MockEvent
|
||||
|
||||
user_id = "@test_user:test"
|
||||
user_localpart = "test_user"
|
||||
user2_id = "@test_user2:test"
|
||||
|
||||
|
||||
class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
|
@ -453,9 +455,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||
]
|
||||
|
||||
user_filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
results = self.get_success(user_filter.filter_presence(presence_states))
|
||||
|
@ -483,9 +483,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||
]
|
||||
|
||||
user_filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart + "2", filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user2_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
results = self.get_success(user_filter.filter_presence(presence_states))
|
||||
|
@ -502,9 +500,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||
events = [event]
|
||||
|
||||
user_filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
results = self.get_success(user_filter.filter_room_state(events=events))
|
||||
|
@ -523,9 +519,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||
events = [event]
|
||||
|
||||
user_filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
results = self.get_success(user_filter.filter_room_state(events))
|
||||
|
@ -607,9 +601,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||
user_filter_json,
|
||||
(
|
||||
self.get_success(
|
||||
self.datastore.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=0
|
||||
)
|
||||
self.datastore.get_user_filter(user_id=user_id, filter_id=0)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
@ -624,9 +616,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
self.assertEqual(filter.get_filter_json(), user_filter_json)
|
||||
|
|
|
@ -45,7 +45,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body, {"filter_id": "0"})
|
||||
filter = self.get_success(
|
||||
self.store.get_user_filter(user_localpart="apple", filter_id=0)
|
||||
self.store.get_user_filter(user_id="@apple:test", filter_id=0)
|
||||
)
|
||||
self.pump()
|
||||
self.assertEqual(filter, self.EXAMPLE_FILTER)
|
||||
|
|
Loading…
Reference in a new issue