Include approximate count of search results

This commit is contained in:
Erik Johnston 2015-12-11 11:40:23 +00:00
parent 51fb590c0e
commit d9a5c56930
2 changed files with 61 additions and 3 deletions

View file

@ -152,11 +152,15 @@ class SearchHandler(BaseHandler):
highlights = set() highlights = set()
count = None
if order_by == "rank": if order_by == "rank":
search_result = yield self.store.search_msgs( search_result = yield self.store.search_msgs(
room_ids, search_term, keys room_ids, search_term, keys
) )
count = search_result["count"]
if search_result["highlights"]: if search_result["highlights"]:
highlights.update(search_result["highlights"]) highlights.update(search_result["highlights"])
@ -207,6 +211,8 @@ class SearchHandler(BaseHandler):
if search_result["highlights"]: if search_result["highlights"]:
highlights.update(search_result["highlights"]) highlights.update(search_result["highlights"])
count = search_result["count"]
results = search_result["results"] results = search_result["results"]
results_map = {r["event"].event_id: r for r in results} results_map = {r["event"].event_id: r for r in results}
@ -359,7 +365,7 @@ class SearchHandler(BaseHandler):
rooms_cat_res = { rooms_cat_res = {
"results": results, "results": results,
"count": len(results), "count": count,
"highlights": list(highlights), "highlights": list(highlights),
} }

View file

@ -162,6 +162,9 @@ class SearchStore(BackgroundUpdateStore):
"(%s)" % (" OR ".join(local_clauses),) "(%s)" % (" OR ".join(local_clauses),)
) )
count_args = args
count_clauses = clauses
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
sql = ( sql = (
"SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank," "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank,"
@ -170,6 +173,12 @@ class SearchStore(BackgroundUpdateStore):
" WHERE vector @@ to_tsquery('english', ?)" " WHERE vector @@ to_tsquery('english', ?)"
) )
args = [search_query, search_query] + args args = [search_query, search_query] + args
count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
" WHERE vector @@ to_tsquery('english', ?)"
)
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine): elif isinstance(self.database_engine, Sqlite3Engine):
sql = ( sql = (
"SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
@ -177,6 +186,12 @@ class SearchStore(BackgroundUpdateStore):
" WHERE value MATCH ?" " WHERE value MATCH ?"
) )
args = [search_query] + args args = [search_query] + args
count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
" WHERE value MATCH ? AND "
)
count_args = [search_term] + count_args
else: else:
# This should be unreachable. # This should be unreachable.
raise Exception("Unrecognized database engine") raise Exception("Unrecognized database engine")
@ -184,6 +199,9 @@ class SearchStore(BackgroundUpdateStore):
for clause in clauses: for clause in clauses:
sql += " AND " + clause sql += " AND " + clause
for clause in count_clauses:
count_sql += " AND " + clause
# We add an arbitrary limit here to ensure we don't try to pull the # We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database. # entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500" sql += " ORDER BY rank DESC LIMIT 500"
@ -205,6 +223,14 @@ class SearchStore(BackgroundUpdateStore):
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
highlights = yield self._find_highlights_in_postgres(search_query, events) highlights = yield self._find_highlights_in_postgres(search_query, events)
count_sql += " GROUP BY room_id"
count_results = yield self._execute(
"search_rooms_count", self.cursor_to_dict, count_sql, *count_args
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
defer.returnValue({ defer.returnValue({
"results": [ "results": [
{ {
@ -215,6 +241,7 @@ class SearchStore(BackgroundUpdateStore):
if r["event_id"] in event_map if r["event_id"] in event_map
], ],
"highlights": highlights, "highlights": highlights,
"count": count,
}) })
@defer.inlineCallbacks @defer.inlineCallbacks
@ -254,6 +281,9 @@ class SearchStore(BackgroundUpdateStore):
"(%s)" % (" OR ".join(local_clauses),) "(%s)" % (" OR ".join(local_clauses),)
) )
count_args = args
count_clauses = clauses
if pagination_token: if pagination_token:
try: try:
origin_server_ts, stream = pagination_token.split(",") origin_server_ts, stream = pagination_token.split(",")
@ -276,7 +306,13 @@ class SearchStore(BackgroundUpdateStore):
" NATURAL JOIN events" " NATURAL JOIN events"
" WHERE vector @@ to_tsquery('english', ?) AND " " WHERE vector @@ to_tsquery('english', ?) AND "
) )
args = [search_term, search_term] + args args = [search_query, search_query] + args
count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
" WHERE vector @@ to_tsquery('english', ?) AND "
)
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine): elif isinstance(self.database_engine, Sqlite3Engine):
# We use CROSS JOIN here to ensure we use the right indexes. # We use CROSS JOIN here to ensure we use the right indexes.
# https://sqlite.org/optoverview.html#crossjoin # https://sqlite.org/optoverview.html#crossjoin
@ -296,12 +332,19 @@ class SearchStore(BackgroundUpdateStore):
" CROSS JOIN events USING (event_id)" " CROSS JOIN events USING (event_id)"
" WHERE " " WHERE "
) )
args = [search_term] + args args = [search_query] + args
count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
" WHERE value MATCH ? AND "
)
count_args = [search_term] + count_args
else: else:
# This should be unreachable. # This should be unreachable.
raise Exception("Unrecognized database engine") raise Exception("Unrecognized database engine")
sql += " AND ".join(clauses) sql += " AND ".join(clauses)
count_sql += " AND ".join(count_clauses)
# We add an arbitrary limit here to ensure we don't try to pull the # We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database. # entire table from the database.
@ -326,6 +369,14 @@ class SearchStore(BackgroundUpdateStore):
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
highlights = yield self._find_highlights_in_postgres(search_query, events) highlights = yield self._find_highlights_in_postgres(search_query, events)
count_sql += " GROUP BY room_id"
count_results = yield self._execute(
"search_rooms_count", self.cursor_to_dict, count_sql, *count_args
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
defer.returnValue({ defer.returnValue({
"results": [ "results": [
{ {
@ -339,6 +390,7 @@ class SearchStore(BackgroundUpdateStore):
if r["event_id"] in event_map if r["event_id"] in event_map
], ],
"highlights": highlights, "highlights": highlights,
"count": count,
}) })
def _find_highlights_in_postgres(self, search_query, events): def _find_highlights_in_postgres(self, search_query, events):