De-localpart {Filtering,FilteringWorkerStore}.get_user_filter()

Signed-off-by: Sean Quah <seanq@matrix.org>
This commit is contained in:
Sean Quah 2023-04-15 02:37:47 +01:00
parent 06f9ababc4
commit f98141ceb2
6 changed files with 37 additions and 32 deletions

View file

@ -165,9 +165,9 @@ class Filtering:
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
async def get_user_filter(
self, user_localpart: str, filter_id: Union[int, str]
self, user_id: str, filter_id: Union[int, str]
) -> "FilterCollection":
result = await self.store.get_user_filter(user_localpart, filter_id)
result = await self.store.get_user_filter(user_id, filter_id)
return FilterCollection(self._hs, result)
def add_user_filter(

View file

@ -58,7 +58,7 @@ class GetFilterRestServlet(RestServlet):
try:
filter_collection = await self.filtering.get_user_filter(
user_localpart=target_user.localpart, filter_id=filter_id_int
user_id=user_id, filter_id=filter_id_int
)
except StoreError as e:
if e.code != 404:

View file

@ -178,7 +178,7 @@ class SyncRestServlet(RestServlet):
else:
try:
filter_collection = await self.filtering.get_user_filter(
user.localpart, filter_id
user.to_string(), filter_id
)
except StoreError as err:
if err.code != 404:

View file

@ -24,7 +24,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.types import JsonDict
from synapse.types import JsonDict, UserID
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@ -34,8 +34,9 @@ if TYPE_CHECKING:
class FilteringWorkerStore(SQLBaseStore):
@cached(num_args=2)
async def get_user_filter(
self, user_localpart: str, filter_id: Union[int, str]
self, user_id: str, filter_id: Union[int, str]
) -> JsonDict:
user_localpart = UserID.from_string(user_id).localpart
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN.
try:
@ -43,13 +44,27 @@ class FilteringWorkerStore(SQLBaseStore):
except ValueError:
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
def_json = await self.db_pool.simple_select_one_onecol(
table="user_filters",
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json",
allow_none=False,
desc="get_user_filter",
)
user_localpart = UserID.from_string(user_id).localpart
try:
def_json = await self.db_pool.simple_select_one_onecol(
table="user_filters",
keyvalues={"full_user_id": user_id, "filter_id": filter_id},
retcol="filter_json",
allow_none=False,
desc="get_user_filter",
)
except StoreError as e:
if e.code == 404:
# Fall back to the `user_id` column.
def_json = await self.db_pool.simple_select_one_onecol(
table="user_filters",
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json",
allow_none=False,
desc="get_user_filter",
)
else:
raise
return db_to_json(def_json)

View file

@ -33,7 +33,9 @@ from synapse.util.frozenutils import freeze
from tests import unittest
from tests.events.test_utils import MockEvent
user_id = "@test_user:test"
user_localpart = "test_user"
user2_id = "@test_user2:test"
class FilteringTestCase(unittest.HomeserverTestCase):
@ -453,9 +455,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
]
user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
)
results = self.get_success(user_filter.filter_presence(presence_states))
@ -483,9 +483,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
]
user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart + "2", filter_id=filter_id
)
self.filtering.get_user_filter(user_id=user2_id, filter_id=filter_id)
)
results = self.get_success(user_filter.filter_presence(presence_states))
@ -502,9 +500,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
events = [event]
user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
)
results = self.get_success(user_filter.filter_room_state(events=events))
@ -523,9 +519,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
events = [event]
user_filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
)
results = self.get_success(user_filter.filter_room_state(events))
@ -607,9 +601,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
user_filter_json,
(
self.get_success(
self.datastore.get_user_filter(
user_localpart=user_localpart, filter_id=0
)
self.datastore.get_user_filter(user_id=user_id, filter_id=0)
)
),
)
@ -624,9 +616,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
filter = self.get_success(
self.filtering.get_user_filter(
user_localpart=user_localpart, filter_id=filter_id
)
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
)
self.assertEqual(filter.get_filter_json(), user_filter_json)

View file

@ -45,7 +45,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"filter_id": "0"})
filter = self.get_success(
self.store.get_user_filter(user_localpart="apple", filter_id=0)
self.store.get_user_filter(user_id="@apple:test", filter_id=0)
)
self.pump()
self.assertEqual(filter, self.EXAMPLE_FILTER)