diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index 650c49982d..ebcc8b9633 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -19,6 +19,7 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.api.constants import EventTypes, JoinRules from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import get_domain_from_id, get_localpart_from_id class UserDirectoryStore(SQLBaseStore): @@ -50,26 +51,39 @@ class UserDirectoryStore(SQLBaseStore): sql = """ INSERT INTO user_directory (user_id, room_id, display_name, avatar_url, vector) - VALUES (?,?,?,?,to_tsvector('english', ?)) + VALUES (?,?,?,?, + setweight(to_tsvector('english', ?), 'A') + || to_tsvector('english', ?) + || to_tsvector('english', COALESCE(?, '')) + ) """ + args = ( + ( + user_id, room_id, p.display_name, p.avatar_url, + get_localpart_from_id(user_id), get_domain_from_id(user_id), + p.display_name, + ) + for user_id, p in users_with_profile.iteritems() + ) elif isinstance(self.database_engine, Sqlite3Engine): sql = """ INSERT INTO user_directory (user_id, room_id, display_name, avatar_url, value) VALUES (?,?,?,?,?) """ - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - - def _add_profiles_to_user_dir_txn(txn): - txn.executemany(sql, ( + args = ( ( user_id, room_id, p.display_name, p.avatar_url, "%s %s" % (user_id, p.display_name,) if p.display_name else user_id ) for user_id, p in users_with_profile.iteritems() - )) + ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + def _add_profiles_to_user_dir_txn(txn): + txn.executemany(sql, args) for user_id in users_with_profile: txn.call_after( self.get_user_in_directory.invalidate, (user_id,) @@ -160,8 +174,8 @@ class UserDirectoryStore(SQLBaseStore): sql = """ SELECT user_id, display_name, avatar_url FROM user_directory - WHERE vector @@ to_tsquery('english', ?) - ORDER BY ts_rank_cd(vector, to_tsquery('english', ?)) DESC + WHERE vector @@ plainto_tsquery('english', ?) + ORDER BY ts_rank_cd(vector, plainto_tsquery('english', ?)) DESC LIMIT ? """ args = (search_term, search_term, limit + 1,) diff --git a/synapse/types.py b/synapse/types.py index 445bdcb4d7..111948540d 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -62,6 +62,13 @@ def get_domain_from_id(string): return string[idx + 1:] +def get_localpart_from_id(string): + idx = string.find(":") + if idx == -1: + raise SynapseError(400, "Invalid ID: %r" % (string,)) + return string[1:idx] + + class DomainSpecificString( namedtuple("DomainSpecificString", ("localpart", "domain")) ):