Add basic full text search impl.

This commit is contained in:
Erik Johnston 2015-10-09 15:48:31 +01:00
parent db6e1e1fe3
commit c85c912562
8 changed files with 268 additions and 1 deletions

View file

@ -84,3 +84,22 @@ class RoomCreationPreset(object):
PRIVATE_CHAT = "private_chat" PRIVATE_CHAT = "private_chat"
PUBLIC_CHAT = "public_chat" PUBLIC_CHAT = "public_chat"
TRUSTED_PRIVATE_CHAT = "trusted_private_chat" TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
class SearchConstraintTypes(object):
    """String identifiers for the kinds of constraint a search can apply.

    The per-type matching semantics are not defined here; the names suggest
    full text, exact, prefix, substring and range matching respectively
    (NOTE(review): confirm against the code that consumes each type).
    """

    FTS = "fts"

    EXACT = "exact"

    PREFIX = "prefix"

    SUBSTRING = "substring"

    RANGE = "range"
class KnownRoomEventKeys(object):
    """Dotted names of the room event fields known to search."""

    # Fields nested under the event's "content".
    CONTENT_BODY = "content.body"
    CONTENT_MSGTYPE = "content.msgtype"
    CONTENT_NAME = "content.name"
    CONTENT_TOPIC = "content.topic"

    # Top-level event fields.
    SENDER = "sender"
    ORIGIN_SERVER_TS = "origin_server_ts"
    ROOM_ID = "room_id"

View file

@ -32,6 +32,7 @@ from .sync import SyncHandler
from .auth import AuthHandler from .auth import AuthHandler
from .identity import IdentityHandler from .identity import IdentityHandler
from .receipts import ReceiptsHandler from .receipts import ReceiptsHandler
from .search import SearchHandler
class Handlers(object): class Handlers(object):
@ -68,3 +69,4 @@ class Handlers(object):
self.sync_handler = SyncHandler(hs) self.sync_handler = SyncHandler(hs)
self.auth_handler = AuthHandler(hs) self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs) self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs)

View file

@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import KnownRoomEventKeys, SearchConstraintTypes
from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event
import logging
logger = logging.getLogger(__name__)


# Maps each event key that may appear in a search constraint to the list of
# constraint types that are permitted for that key.
KEYS_TO_ALLOWED_CONSTRAINT_TYPES = {
    KnownRoomEventKeys.CONTENT_BODY: [SearchConstraintTypes.FTS],
    KnownRoomEventKeys.CONTENT_MSGTYPE: [SearchConstraintTypes.EXACT],
    KnownRoomEventKeys.CONTENT_NAME: [
        SearchConstraintTypes.FTS,
        SearchConstraintTypes.EXACT,
        SearchConstraintTypes.SUBSTRING,
    ],
    KnownRoomEventKeys.CONTENT_TOPIC: [SearchConstraintTypes.FTS],
    KnownRoomEventKeys.SENDER: [SearchConstraintTypes.EXACT],
    KnownRoomEventKeys.ORIGIN_SERVER_TS: [SearchConstraintTypes.RANGE],
    KnownRoomEventKeys.ROOM_ID: [SearchConstraintTypes.EXACT],
}
class RoomConstraint(object):
    """A single validated search constraint on room events.

    Attributes:
        search_type: The constraint type, as given in the request's "type".
        keys: The event keys the constraint applies to.
        value: The value to match against; stored as given, not validated.
    """

    def __init__(self, search_type, keys, value):
        self.search_type = search_type
        self.keys = keys
        self.value = value

    @classmethod
    def from_dict(cls, d):
        """Build a constraint from its dict form, validating keys and type.

        Args:
            d (dict): Must contain "type", "keys" (a list of event keys) and
                "value".

        Returns:
            RoomConstraint: The validated constraint.

        Raises:
            SynapseError: 400 if a required field is missing, "keys" is not a
                list, a key is not recognised, or the constraint type is not
                allowed for one of the keys.
        """
        # Report malformed input as a client error rather than letting a
        # KeyError/TypeError escape from the request handler.
        try:
            search_type = d["type"]
            keys = d["keys"]
            value = d["value"]
        except (KeyError, TypeError):
            raise SynapseError(
                400, "Constraint must have 'type', 'keys' and 'value'"
            )

        if not isinstance(keys, (list, tuple)):
            raise SynapseError(400, "Constraint 'keys' must be a list")

        for key in keys:
            if key not in KEYS_TO_ALLOWED_CONSTRAINT_TYPES:
                # The message has to be interpolated here: extra positional
                # arguments to the exception are not used as format args.
                raise SynapseError(400, "Unrecognized key %r" % (key,))

            if search_type not in KEYS_TO_ALLOWED_CONSTRAINT_TYPES[key]:
                raise SynapseError(
                    400,
                    "Disallowed constraint type %r for key %r" % (
                        search_type, key,
                    )
                )

        return cls(search_type, keys, value)
class SearchHandler(BaseHandler):
    """Handler that runs searches over room events."""

    def __init__(self, hs):
        super(SearchHandler, self).__init__(hs)

    @defer.inlineCallbacks
    def search(self, content):
        """Run the search described by a parsed request body.

        NOTE(review): this method receives no user and applies no filtering
        beyond the supplied constraints, so results do not appear to be
        restricted to rooms the requester may see. Confirm whether that
        check is expected to happen elsewhere.

        Args:
            content (dict): Request body. The constraints are read from
                content["search_categories"]["room_events"]["constraints"],
                which must be a list of constraint dicts.

        Returns:
            Deferred[list[dict]]: One dict per hit with keys "rank" and
            "result" (the serialized event), ordered by descending rank.

        Raises:
            SynapseError: 400 if the body does not have the expected shape,
                or if more than one constraint is of type FTS. Errors raised
                while parsing individual constraints propagate unchanged.
        """
        # A body of the wrong shape is the client's fault: answer with a 400
        # instead of letting KeyError/TypeError turn into a server error.
        try:
            constraint_dicts = (
                content["search_categories"]["room_events"]["constraints"]
            )
        except (KeyError, TypeError):
            raise SynapseError(
                400,
                "Missing search_categories.room_events.constraints"
            )

        if not isinstance(constraint_dicts, list):
            raise SynapseError(400, "'constraints' must be a list")

        constraints = [RoomConstraint.from_dict(c) for c in constraint_dicts]

        fts_count = sum(
            1 for c in constraints
            if c.search_type == SearchConstraintTypes.FTS
        )
        if fts_count > 1:
            raise SynapseError(400, "Only one constraint can be FTS")

        res = yield self.hs.get_datastore().search_msgs(constraints)

        time_now = self.hs.get_clock().time_msec()

        results = [
            {
                "rank": r["rank"],
                "result": serialize_event(r["result"], time_now)
            }
            for r in res
        ]

        # Highest rank first.
        results.sort(key=lambda r: r["rank"], reverse=True)

        logger.info("returning: %r", results)

        defer.returnValue(results)

View file

@ -529,6 +529,22 @@ class RoomTypingRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class SearchRestServlet(ClientV1RestServlet):
    """Servlet answering POST requests on the "/search$" client path."""

    PATTERN = client_path_pattern("/search$")

    @defer.inlineCallbacks
    def on_POST(self, request):
        """Parse the JSON body, run the search and reply with the results.

        Returns:
            Deferred[tuple]: (200, results), where results is whatever the
            search handler's search() produced.
        """
        # The requesting user is looked up but not used any further here.
        _requester, _ = yield self.auth.get_user_by_req(request)

        body = _parse_json(request)
        search_handler = self.handlers.search_handler
        hits = yield search_handler.search(body)

        defer.returnValue((200, hits))
def _parse_json(request): def _parse_json(request):
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
@ -585,3 +601,4 @@ def register_servlets(hs, http_server):
RoomInitialSyncRestServlet(hs).register(http_server) RoomInitialSyncRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server)

View file

@ -40,6 +40,7 @@ from .filtering import FilteringStore
from .end_to_end_keys import EndToEndKeyStore from .end_to_end_keys import EndToEndKeyStore
from .receipts import ReceiptsStore from .receipts import ReceiptsStore
from .search import SearchStore
import fnmatch import fnmatch
@ -79,6 +80,7 @@ class DataStore(RoomMemberStore, RoomStore,
EventsStore, EventsStore,
ReceiptsStore, ReceiptsStore,
EndToEndKeyStore, EndToEndKeyStore,
SearchStore,
): ):
def __init__(self, hs): def __init__(self, hs):

View file

@ -519,7 +519,7 @@ class SQLBaseStore(object):
allow_none=False, allow_none=False,
desc="_simple_select_one_onecol"): desc="_simple_select_one_onecol"):
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it." return a single row, returning a single column from it.
Args: Args:
table : string giving the table name table : string giving the table name

View file

@ -0,0 +1,57 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.storage import get_statements
from synapse.storage.engines import PostgresEngine
logger = logging.getLogger(__name__)
# Schema delta statements run by run_upgrade() on PostgreSQL:
#   1. create the event_search table (event_id, room_id, key, tsvector);
#   2. backfill it from events joined with event_json, indexing
#      content.body of m.room.message events, content.name of m.room.name
#      events and content.topic of m.room.topic events using the 'english'
#      text search configuration;
#   3. add a GIN index on the tsvector column.
POSTGRES_SQL = """
CREATE TABLE event_search (
event_id TEXT,
room_id TEXT,
key TEXT,
vector tsvector
);
INSERT INTO event_search SELECT
event_id, room_id, 'content.body',
to_tsvector('english', json::json->'content'->>'body')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.message';
INSERT INTO event_search SELECT
event_id, room_id, 'content.name',
to_tsvector('english', json::json->'content'->>'name')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.name';
INSERT INTO event_search SELECT
event_id, room_id, 'content.topic',
to_tsvector('english', json::json->'content'->>'topic')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.topic';
CREATE INDEX event_search_idx ON event_search USING gin(vector);
"""
def run_upgrade(cur, database_engine, *args, **kwargs):
    """Apply the event_search schema delta when running on PostgreSQL.

    Args:
        cur: Database cursor; each statement is passed to cur.execute().
        database_engine: The engine in use. Anything that is not a
            PostgresEngine makes this function a no-op.
    """
    # We only support FTS for postgres currently.
    if isinstance(database_engine, PostgresEngine):
        sql_lines = POSTGRES_SQL.splitlines()
        for stmt in get_statements(sql_lines):
            cur.execute(stmt)

75
synapse/storage/search.py Normal file
View file

@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from _base import SQLBaseStore
from synapse.api.constants import KnownRoomEventKeys, SearchConstraintTypes
class SearchStore(SQLBaseStore):
    """Storage methods that query the event_search table."""

    @defer.inlineCallbacks
    def search_msgs(self, constraints):
        """Run a full text search over event_search with the given constraints.

        Only two kinds of constraint are turned into SQL: FTS constraints
        (which supply the search text and restrict the indexed "key") and
        EXACT constraints on room_id. Any other constraint adds no filter.

        Args:
            constraints (list): Objects with `search_type`, `keys` and
                `value` attributes. For an EXACT room_id constraint, `value`
                is iterated, so it is assumed to be a collection of room IDs
                (NOTE(review): confirm with callers).

        Returns:
            Deferred[list[dict]]: Dicts with keys "rank" (the ts_rank_cd
            score) and "result" (the event), in the order returned by the
            query, i.e. by descending rank. Rows whose event was not returned
            by _get_events are omitted.
        """
        clauses = []
        args = []
        fts = None

        for c in constraints:
            local_clauses = []
            if c.search_type == SearchConstraintTypes.FTS:
                fts = c.value
                for key in c.keys:
                    local_clauses.append("key = ?")
                    args.append(key)
            elif c.search_type == SearchConstraintTypes.EXACT:
                for key in c.keys:
                    if key == KnownRoomEventKeys.ROOM_ID:
                        for value in c.value:
                            local_clauses.append("room_id = ?")
                            args.append(value)

            # A constraint that produced no fragment must not be added:
            # joining an empty list would emit "()", which is invalid SQL.
            if local_clauses:
                clauses.append(
                    "(%s)" % (" OR ".join(local_clauses),)
                )

        sql = (
            "SELECT ts_rank_cd(vector, query) AS rank, event_id"
            " FROM plainto_tsquery('english', ?) as query, event_search"
            " WHERE vector @@ query"
        )

        for clause in clauses:
            sql += " AND " + clause

        sql += " ORDER BY rank DESC"

        # The search text binds to the first placeholder (plainto_tsquery);
        # the remaining args follow in the same order as their clauses.
        results = yield self._execute(
            "search_msgs", self.cursor_to_dict, sql, *([fts] + args)
        )

        events = yield self._get_events([r["event_id"] for r in results])

        event_map = {
            ev.event_id: ev
            for ev in events
        }

        defer.returnValue([
            {
                "rank": r["rank"],
                "result": event_map[r["event_id"]]
            }
            for r in results
            if r["event_id"] in event_map
        ])