Convert simple_select_list_paginate_txn to return tuples. (#16433)

This commit is contained in:
Patrick Cloke 2023-10-06 11:41:57 -04:00 committed by GitHub
parent 7615e2bf48
commit 06bbf1029c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 67 additions and 39 deletions

1
changelog.d/16433.misc Normal file
View file

@ -0,0 +1 @@
Reduce memory allocations.

View file

@ -80,10 +80,6 @@ class UserPresenceState:
def as_dict(self) -> JsonDict: def as_dict(self) -> JsonDict:
return attr.asdict(self) return attr.asdict(self)
@staticmethod
def from_dict(d: JsonDict) -> "UserPresenceState":
return UserPresenceState(**d)
def copy_and_replace(self, **kwargs: Any) -> "UserPresenceState": def copy_and_replace(self, **kwargs: Any) -> "UserPresenceState":
return attr.evolve(self, **kwargs) return attr.evolve(self, **kwargs)

View file

@ -395,7 +395,7 @@ class PresenceDestinationsRow(BaseFederationRow):
@staticmethod @staticmethod
def from_data(data: JsonDict) -> "PresenceDestinationsRow": def from_data(data: JsonDict) -> "PresenceDestinationsRow":
return PresenceDestinationsRow( return PresenceDestinationsRow(
state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"] state=UserPresenceState(**data["state"]), destinations=data["dests"]
) )
def to_data(self) -> JsonDict: def to_data(self) -> JsonDict:

View file

@ -198,7 +198,13 @@ class DestinationMembershipRestServlet(RestServlet):
rooms, total = await self._store.get_destination_rooms_paginate( rooms, total = await self._store.get_destination_rooms_paginate(
destination, start, limit, direction destination, start, limit, direction
) )
response = {"rooms": rooms, "total": total} response = {
"rooms": [
{"room_id": room_id, "stream_ordering": stream_ordering}
for room_id, stream_ordering in rooms
],
"total": total,
}
if (start + limit) < total: if (start + limit) < total:
response["next_token"] = str(start + len(rooms)) response["next_token"] = str(start + len(rooms))

View file

@ -2418,7 +2418,7 @@ class DatabasePool:
keyvalues: Optional[Dict[str, Any]] = None, keyvalues: Optional[Dict[str, Any]] = None,
exclude_keyvalues: Optional[Dict[str, Any]] = None, exclude_keyvalues: Optional[Dict[str, Any]] = None,
order_direction: str = "ASC", order_direction: str = "ASC",
) -> List[Dict[str, Any]]: ) -> List[Tuple[Any, ...]]:
""" """
Executes a SELECT query on the named table with start and limit, Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit, of row numbers, which may return zero or number of rows from start to limit,
@ -2447,7 +2447,7 @@ class DatabasePool:
order_direction: Whether the results should be ordered "ASC" or "DESC". order_direction: Whether the results should be ordered "ASC" or "DESC".
Returns: Returns:
The result as a list of dictionaries. The result as a list of tuples.
""" """
if order_direction not in ["ASC", "DESC"]: if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
@ -2474,7 +2474,7 @@ class DatabasePool:
) )
txn.execute(sql, arg_list + [limit, start]) txn.execute(sql, arg_list + [limit, start])
return cls.cursor_to_dict(txn) return txn.fetchall()
async def simple_search_list( async def simple_search_list(
self, self,

View file

@ -20,6 +20,7 @@ from typing import (
Mapping, Mapping,
Optional, Optional,
Tuple, Tuple,
Union,
cast, cast,
) )
@ -385,7 +386,9 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
limit = 100 limit = 100
offset = 0 offset = 0
while True: while True:
rows = await self.db_pool.runInteraction( rows = cast(
List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
await self.db_pool.runInteraction(
"get_presence_for_all_users", "get_presence_for_all_users",
self.db_pool.simple_select_list_paginate_txn, self.db_pool.simple_select_list_paginate_txn,
"presence_stream", "presence_stream",
@ -403,10 +406,27 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
"currently_active", "currently_active",
), ),
order_direction="ASC", order_direction="ASC",
),
) )
for row in rows: for (
users_to_state[row["user_id"]] = UserPresenceState(**row) user_id,
state,
last_active_ts,
last_federation_update_ts,
last_user_sync_ts,
status_msg,
currently_active,
) in rows:
users_to_state[user_id] = UserPresenceState(
user_id=user_id,
state=state,
last_active_ts=last_active_ts,
last_federation_update_ts=last_federation_update_ts,
last_user_sync_ts=last_user_sync_ts,
status_msg=status_msg,
currently_active=bool(currently_active),
)
# We've run out of updates to query # We've run out of updates to query
if len(rows) < limit: if len(rows) < limit:

View file

@ -526,7 +526,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
start: int, start: int,
limit: int, limit: int,
direction: Direction = Direction.FORWARDS, direction: Direction = Direction.FORWARDS,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[Tuple[str, int]], int]:
"""Function to retrieve a paginated list of destination's rooms. """Function to retrieve a paginated list of destination's rooms.
This will return a json list of rooms and the This will return a json list of rooms and the
total number of rooms. total number of rooms.
@ -537,12 +537,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
limit: number of rows to retrieve limit: number of rows to retrieve
direction: sort ascending or descending by room_id direction: sort ascending or descending by room_id
Returns: Returns:
A tuple of a dict of rooms and a count of total rooms. A tuple of a list of room tuples and a count of total rooms.
Each room tuple is room_id, stream_ordering.
""" """
def get_destination_rooms_paginate_txn( def get_destination_rooms_paginate_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[Tuple[str, int]], int]:
if direction == Direction.BACKWARDS: if direction == Direction.BACKWARDS:
order = "DESC" order = "DESC"
else: else:
@ -556,7 +558,9 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql, [destination]) txn.execute(sql, [destination])
count = cast(Tuple[int], txn.fetchone())[0] count = cast(Tuple[int], txn.fetchone())[0]
rooms = self.db_pool.simple_select_list_paginate_txn( rooms = cast(
List[Tuple[str, int]],
self.db_pool.simple_select_list_paginate_txn(
txn=txn, txn=txn,
table="destination_rooms", table="destination_rooms",
orderby="room_id", orderby="room_id",
@ -564,6 +568,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
limit=limit, limit=limit,
retcols=("room_id", "stream_ordering"), retcols=("room_id", "stream_ordering"),
order_direction=order, order_direction=order,
),
) )
return rooms, count return rooms, count