Convert simple_select_one_txn and simple_select_one to return tuples. (#16612)

This commit is contained in:
Patrick Cloke 2023-11-09 11:13:31 -05:00 committed by GitHub
parent ff716b483b
commit ab3f1b3b53
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
33 changed files with 283 additions and 279 deletions

1
changelog.d/16612.misc Normal file
View file

@ -0,0 +1 @@
Improve type hints.

View file

@ -348,8 +348,7 @@ class Porter:
backward_chunk = 0
already_ported = 0
else:
forward_chunk = row["forward_rowid"]
backward_chunk = row["backward_rowid"]
forward_chunk, backward_chunk = row
if total_to_port is None:
already_ported, total_to_port = await self._get_total_count_to_port(

View file

@ -269,7 +269,7 @@ class RoomCreationHandler:
self,
requester: Requester,
old_room_id: str,
old_room: Dict[str, Any],
old_room: Tuple[bool, str, bool],
new_room_id: str,
new_version: RoomVersion,
tombstone_event: EventBase,
@ -279,7 +279,7 @@ class RoomCreationHandler:
Args:
requester: the user requesting the upgrade
old_room_id: the id of the room to be replaced
old_room: a dict containing room information for the room to be replaced,
old_room: a tuple containing room information for the room to be replaced,
as returned by `RoomWorkerStore.get_room`.
new_room_id: the id of the replacement room
new_version: the version to upgrade the room to
@ -299,7 +299,7 @@ class RoomCreationHandler:
await self.store.store_room(
room_id=new_room_id,
room_creator_user_id=user_id,
is_public=old_room["is_public"],
is_public=old_room[0],
room_version=new_version,
)

View file

@ -1260,7 +1260,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Add new room to the room directory if the old room was there
# Remove old room from the room directory
old_room = await self.store.get_room(old_room_id)
if old_room is not None and old_room["is_public"]:
# If the old room exists and is public.
if old_room is not None and old_room[0]:
await self.store.set_room_is_public(old_room_id, False)
await self.store.set_room_is_public(room_id, True)

View file

@ -1860,7 +1860,8 @@ class PublicRoomListManager:
if not room:
return False
return room.get("is_public", False)
# The first item is whether the room is public.
return room[0]
async def add_room_to_public_room_list(self, room_id: str) -> None:
"""Publishes a room to the public room list.

View file

@ -413,8 +413,8 @@ class RoomMembersRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id)
if not ret:
room = await self.store.get_room(room_id)
if not room:
raise NotFoundError("Room not found")
members = await self.store.get_users_in_room(room_id)
@ -442,8 +442,8 @@ class RoomStateRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id)
if not ret:
room = await self.store.get_room(room_id)
if not room:
raise NotFoundError("Room not found")
event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)

View file

@ -147,7 +147,7 @@ class ClientDirectoryListServer(RestServlet):
if room is None:
raise NotFoundError("Unknown room")
return 200, {"visibility": "public" if room["is_public"] else "private"}
return 200, {"visibility": "public" if room[0] else "private"}
class PutBody(RequestBodyModel):
visibility: Literal["public", "private"] = "public"

View file

@ -1597,7 +1597,7 @@ class DatabasePool:
retcols: Collection[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one",
) -> Dict[str, Any]:
) -> Tuple[Any, ...]:
...
@overload
@ -1608,7 +1608,7 @@ class DatabasePool:
retcols: Collection[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
...
async def simple_select_one(
@ -1618,7 +1618,7 @@ class DatabasePool:
retcols: Collection[str],
allow_none: bool = False,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
@ -2127,7 +2127,7 @@ class DatabasePool:
keyvalues: Dict[str, Any],
retcols: Collection[str],
allow_none: bool = False,
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
if keyvalues:
@ -2145,7 +2145,7 @@ class DatabasePool:
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
return dict(zip(retcols, row))
return row
async def simple_delete_one(
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"

View file

@ -255,33 +255,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
A dict containing the device information, or `None` if the device does not
exist.
"""
return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
allow_none=True,
)
async def get_device_opt(
self, user_id: str, device_id: str
) -> Optional[Dict[str, Any]]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
Args:
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
A dict containing the device information, or None if the device does not exist.
"""
return await self.db_pool.simple_select_one(
row = await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
allow_none=True,
)
if row is None:
return None
return {"user_id": row[0], "device_id": row[1], "display_name": row[2]}
async def get_devices_by_user(
self, user_id: str
@ -1221,9 +1204,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
retcols=["device_id", "device_data"],
allow_none=True,
)
return (
(row["device_id"], json_decoder.decode(row["device_data"])) if row else None
)
return (row[0], json_decoder.decode(row[1])) if row else None
def _store_dehydrated_device_txn(
self,
@ -2326,13 +2307,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
`FALSE` have not been converted.
"""
row = await self.db_pool.simple_select_one(
return cast(
Tuple[int, str],
await self.db_pool.simple_select_one(
table="device_lists_changes_converted_stream_position",
keyvalues={},
retcols=["stream_id", "room_id"],
desc="get_device_change_last_converted_pos",
),
)
return row["stream_id"], row["room_id"]
async def set_device_change_last_converted_pos(
self,

View file

@ -506,19 +506,26 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
# it isn't there.
raise StoreError(404, "No backup with that version exists")
result = self.db_pool.simple_select_one_txn(
row = cast(
Tuple[int, str, str, Optional[int]],
self.db_pool.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
keyvalues={
"user_id": user_id,
"version": this_version,
"deleted": 0,
},
retcols=("version", "algorithm", "auth_data", "etag"),
allow_none=False,
),
)
assert result is not None # see comment on `simple_select_one_txn`
result["auth_data"] = db_to_json(result["auth_data"])
result["version"] = str(result["version"])
if result["etag"] is None:
result["etag"] = 0
return result
return {
"auth_data": db_to_json(row[2]),
"version": str(row[0]),
"algorithm": row[1],
"etag": 0 if row[3] is None else row[3],
}
return await self.db_pool.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn

View file

@ -1266,9 +1266,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
if row is None:
continue
key_id = row["key_id"]
key_json = row["key_json"]
used = row["used"]
key_id, key_json, used = row
# Mark fallback key as used if not already.
if not used and mark_as_used:

View file

@ -193,7 +193,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
# If the room has an auth chain index.
if room[1]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_ids_chains",
@ -411,7 +412,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
# If the room has an auth chain index.
if room[1]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_difference_chains",
@ -1437,24 +1439,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
if event_lookup_result is not None:
event_type, depth, stream_ordering = event_lookup_result
logger.debug(
"_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
room_id,
seed_event_id,
event_lookup_result["depth"],
event_lookup_result["stream_ordering"],
event_lookup_result["type"],
depth,
stream_ordering,
event_type,
)
if event_lookup_result["depth"]:
queue.put(
(
-event_lookup_result["depth"],
-event_lookup_result["stream_ordering"],
seed_event_id,
event_lookup_result["type"],
)
)
if depth:
queue.put((-depth, -stream_ordering, seed_event_id, event_type))
while not queue.empty() and len(event_id_results) < limit:
try:

View file

@ -1934,8 +1934,7 @@ class PersistEventsStore:
if row is None:
return
redacted_relates_to = row["relates_to_id"]
rel_type = row["relation_type"]
redacted_relates_to, rel_type = row
self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)

View file

@ -1998,7 +1998,7 @@ class EventsWorkerStore(SQLBaseStore):
if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,))
return int(res["topological_ordering"]), int(res["stream_ordering"])
return int(res[0]), int(res[1])
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry

View file

@ -208,7 +208,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
if row is None:
return None
return LocalMedia(media_id=media_id, **row)
return LocalMedia(
media_id=media_id,
media_type=row[0],
media_length=row[1],
upload_name=row[2],
created_ts=row[3],
quarantined_by=row[4],
url_cache=row[5],
last_access_ts=row[6],
safe_from_quarantine=row[7],
)
async def get_local_media_by_user_paginate(
self,
@ -541,7 +551,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
if row is None:
return row
return RemoteMedia(media_origin=origin, media_id=media_id, **row)
return RemoteMedia(
media_origin=origin,
media_id=media_id,
media_type=row[0],
media_length=row[1],
upload_name=row[2],
created_ts=row[3],
filesystem_id=row[4],
last_access_ts=row[5],
quarantined_by=row[6],
)
async def store_cached_remote_media(
self,
@ -665,11 +685,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
if row is None:
return None
return ThumbnailInfo(
width=row["thumbnail_width"],
height=row["thumbnail_height"],
method=row["thumbnail_method"],
type=row["thumbnail_type"],
length=row["thumbnail_length"],
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
)
@trace

View file

@ -13,7 +13,6 @@
# limitations under the License.
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@ -138,23 +137,18 @@ class ProfileWorkerStore(SQLBaseStore):
return 50
async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
try:
profile = await self.db_pool.simple_select_one(
table="profiles",
keyvalues={"full_user_id": user_id.to_string()},
retcols=("displayname", "avatar_url"),
desc="get_profileinfo",
allow_none=True,
)
except StoreError as e:
if e.code == 404:
if profile is None:
# no match
return ProfileInfo(None, None)
else:
raise
return ProfileInfo(
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
return ProfileInfo(avatar_url=profile[1], display_name=profile[0])
async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(

View file

@ -468,8 +468,7 @@ class PushRuleStore(PushRulesWorkerStore):
"before/after rule not found: %s" % (relative_to_rule,)
)
base_priority_class = res["priority_class"]
base_rule_priority = res["priority"]
base_priority_class, base_rule_priority = res
if base_priority_class != priority_class:
raise InconsistentRuleException(

View file

@ -701,8 +701,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
allow_none=True,
)
stream_ordering = int(res["stream_ordering"]) if res else None
rx_ts = res["received_ts"] if res else 0
stream_ordering = int(res[0]) if res else None
rx_ts = res[1] if res else 0
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts

View file

@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
account timestamp as milliseconds since the epoch. None if the account
has not been renewed using the current token yet.
"""
ret_dict = await self.db_pool.simple_select_one(
return cast(
Tuple[str, int, Optional[int]],
await self.db_pool.simple_select_one(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
desc="get_user_from_renewal_token",
)
return (
ret_dict["user_id"],
ret_dict["expiration_ts_ms"],
ret_dict["token_used_ts_ms"],
),
)
async def get_renewal_token_for_user(self, user_id: str) -> str:
@ -989,16 +986,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns:
user id, or None if no user id/threepid mapping exists
"""
ret = self.db_pool.simple_select_one_txn(
return self.db_pool.simple_select_one_onecol_txn(
txn,
"user_threepids",
{"medium": medium, "address": address},
["user_id"],
"user_id",
True,
)
if ret:
return ret["user_id"]
return None
async def user_add_threepid(
self,
@ -1435,16 +1429,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
if res is None:
return False
uses_allowed, pending, completed, expiry_time = res
# Check if the token has expired
now = self._clock.time_msec()
if res["expiry_time"] and res["expiry_time"] < now:
if expiry_time and expiry_time < now:
return False
# Check if the token has been used up
if (
res["uses_allowed"]
and res["pending"] + res["completed"] >= res["uses_allowed"]
):
if uses_allowed and pending + completed >= uses_allowed:
return False
# Otherwise, the token is valid
@ -1490,8 +1483,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
# about None not being indexable.
res = cast(
Dict[str, Any],
pending, completed = cast(
Tuple[int, int],
self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
@ -1506,8 +1499,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"registration_tokens",
keyvalues={"token": token},
updatevalues={
"completed": res["completed"] + 1,
"pending": res["pending"] - 1,
"completed": completed + 1,
"pending": pending - 1,
},
)
@ -1585,13 +1578,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns:
A dict, or None if token doesn't exist.
"""
return await self.db_pool.simple_select_one(
row = await self.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
allow_none=True,
desc="get_one_registration_token",
)
if row is None:
return None
return {
"token": row[0],
"uses_allowed": row[1],
"pending": row[2],
"completed": row[3],
"expiry_time": row[4],
}
async def generate_registration_token(
self, length: int, chars: str
@ -1714,7 +1716,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return None
# Get all info about the token so it can be sent in the response
return self.db_pool.simple_select_one_txn(
result = self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
@ -1728,6 +1730,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
allow_none=True,
)
if result is None:
return result
return {
"token": result[0],
"uses_allowed": result[1],
"pending": result[2],
"completed": result[3],
"expiry_time": result[4],
}
return await self.db_pool.runInteraction(
"update_registration_token", _update_registration_token_txn
)
@ -1939,11 +1952,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
keyvalues={"token": token},
updatevalues={"used_ts": ts},
)
user_id = values["user_id"]
expiry_ts = values["expiry_ts"]
used_ts = values["used_ts"]
auth_provider_id = values["auth_provider_id"]
auth_provider_session_id = values["auth_provider_session_id"]
(
user_id,
expiry_ts,
used_ts,
auth_provider_id,
auth_provider_session_id,
) = values
# Token was already used
if used_ts is not None:
@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
# reason, the next check is on the client secret, which is NOT NULL,
# so we don't have to worry about the client secret matching by
# accident.
row = {"client_secret": None, "validated_at": None}
row = None, None
else:
raise ThreepidValidationError("Unknown session_id")
retrieved_client_secret = row["client_secret"]
validated_at = row["validated_at"]
retrieved_client_secret, validated_at = row
row = self.db_pool.simple_select_one_txn(
txn,
@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
raise ThreepidValidationError(
"Validation token not found or has expired"
)
expires = row["expires"]
next_link = row["next_link"]
expires, next_link = row
if retrieved_client_secret != client_secret:
raise ThreepidValidationError(

View file

@ -213,21 +213,31 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]:
"""Retrieve a room.
Args:
room_id: The ID of the room to retrieve.
Returns:
A dict containing the room information, or None if the room is unknown.
A tuple containing the room information:
* True if the room is public
* True if the room has an auth chain index
or None if the room is unknown.
"""
return await self.db_pool.simple_select_one(
row = cast(
Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]],
await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
retcols=("is_public", "has_auth_chain_index"),
desc="get_room",
allow_none=True,
),
)
if row is None:
return row
return bool(row[0]), bool(row[1])
async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
"""Retrieve room with statistics.
@ -794,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
)
if row:
return RatelimitOverride(
messages_per_second=row["messages_per_second"],
burst_count=row["burst_count"],
)
return RatelimitOverride(messages_per_second=row[0], burst_count=row[1])
else:
return None
@ -1371,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
join.
"""
result = await self.db_pool.simple_select_one(
return cast(
Tuple[str, int],
await self.db_pool.simple_select_one(
table="partial_state_rooms",
keyvalues={"room_id": room_id},
retcols=("join_event_id", "device_lists_stream_id"),
desc="get_join_event_id_for_partial_state",
),
)
return result["join_event_id"], result["device_lists_stream_id"]
def get_un_partial_stated_rooms_token(self, instance_name: str) -> int:
return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer(

View file

@ -559,17 +559,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
"non-local user %s" % (user_id,),
)
results_dict = await self.db_pool.simple_select_one(
results = cast(
Optional[Tuple[str, str]],
await self.db_pool.simple_select_one(
"local_current_membership",
{"room_id": room_id, "user_id": user_id},
("membership", "event_id"),
allow_none=True,
desc="get_local_current_membership_for_user_in_room",
),
)
if not results_dict:
if not results:
return None, None
return results_dict.get("membership"), results_dict.get("event_id")
return results
@cached(max_entries=500000, iterable=True)
async def get_rooms_for_user_with_stream_ordering(

View file

@ -1014,9 +1014,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="get_position_for_event",
)
return PersistedEventPosition(
row["instance_name"] or "master", row["stream_ordering"]
)
return PersistedEventPosition(row[1] or "master", row[0])
async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
"""The stream token for an event
@ -1033,9 +1031,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
)
return RoomStreamToken(
topological=row["topological_ordering"], stream=row["stream_ordering"]
)
return RoomStreamToken(topological=row[1], stream=row[0])
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
"""Gets the topological token in a room after or at the given stream
@ -1180,26 +1176,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
results = self.db_pool.simple_select_one_txn(
stream_ordering, topological_ordering = cast(
Tuple[int, int],
self.db_pool.simple_select_one_txn(
txn,
"events",
keyvalues={"event_id": event_id, "room_id": room_id},
retcols=["stream_ordering", "topological_ordering"],
),
)
# This cannot happen as `allow_none=False`.
assert results is not None
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
topological=results["topological_ordering"] - 1,
stream=results["stream_ordering"],
topological=topological_ordering - 1, stream=stream_ordering
)
after_token = RoomStreamToken(
topological=results["topological_ordering"],
stream=results["stream_ordering"],
topological=topological_ordering, stream=stream_ordering
)
rows, start_token = self._paginate_room_events_txn(

View file

@ -183,7 +183,9 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
Returns: the task if available, `None` otherwise
"""
row = await self.db_pool.simple_select_one(
row = cast(
Optional[ScheduledTaskRow],
await self.db_pool.simple_select_one(
table="scheduled_tasks",
keyvalues={"id": id},
retcols=(
@ -198,24 +200,10 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
),
allow_none=True,
desc="get_scheduled_task",
),
)
return (
TaskSchedulerWorkerStore._convert_row_to_task(
(
row["id"],
row["action"],
row["status"],
row["timestamp"],
row["resource_id"],
row["params"],
row["result"],
row["error"],
)
)
if row
else None
)
return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
async def delete_scheduled_task(self, id: str) -> None:
"""Delete a specific task from its id.

View file

@ -118,19 +118,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
txn,
table="received_transactions",
keyvalues={"transaction_id": transaction_id, "origin": origin},
retcols=(
"transaction_id",
"origin",
"ts",
"response_code",
"response_json",
"has_been_referenced",
),
retcols=("response_code", "response_json"),
allow_none=True,
)
if result and result["response_code"]:
return result["response_code"], db_to_json(result["response_json"])
# If the result exists and the response code is non-0.
if result and result[0]:
return result[0], db_to_json(result[1])
else:
return None
@ -200,8 +194,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
# check we have a row and retry_last_ts is not null or zero
# (retry_last_ts can't be negative)
if result and result["retry_last_ts"]:
return DestinationRetryTimings(**result)
if result and result[1]:
return DestinationRetryTimings(
failure_ts=result[0], retry_last_ts=result[1], retry_interval=result[2]
)
else:
return None

View file

@ -122,9 +122,13 @@ class UIAuthWorkerStore(SQLBaseStore):
desc="get_ui_auth_session",
)
result["clientdict"] = db_to_json(result["clientdict"])
return UIAuthSessionData(session_id, **result)
return UIAuthSessionData(
session_id,
clientdict=db_to_json(result[0]),
uri=result[1],
method=result[2],
description=result[3],
)
async def mark_ui_auth_stage_complete(
self,
@ -231,18 +235,15 @@ class UIAuthWorkerStore(SQLBaseStore):
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
) -> None:
# Get the current value.
result = cast(
Dict[str, Any],
self.db_pool.simple_select_one_txn(
result = self.db_pool.simple_select_one_onecol_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
),
retcol="serverdict",
)
# Update it and add it back to the database.
serverdict = db_to_json(result["serverdict"])
serverdict = db_to_json(result)
serverdict[key] = value
self.db_pool.simple_update_one_txn(
@ -265,14 +266,14 @@ class UIAuthWorkerStore(SQLBaseStore):
Raises:
StoreError if the session cannot be found.
"""
result = await self.db_pool.simple_select_one(
result = await self.db_pool.simple_select_one_onecol(
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
retcol="serverdict",
desc="get_ui_auth_session_data",
)
serverdict = db_to_json(result["serverdict"])
serverdict = db_to_json(result)
return serverdict.get(key, default)

View file

@ -20,7 +20,6 @@ from typing import (
Collection,
Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
@ -833,13 +832,25 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
return await self.db_pool.simple_select_one(
async def _get_user_in_directory(
self, user_id: str
) -> Optional[Tuple[Optional[str], Optional[str]]]:
"""
Fetch the user information in the user directory.
Returns:
None if the user is unknown, otherwise a tuple of display name and
avatar URL (both of which may be None).
"""
return cast(
Optional[Tuple[Optional[str], Optional[str]]],
await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
allow_none=True,
desc="get_user_in_directory",
),
)
async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:

View file

@ -84,7 +84,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
return self.get_success(
row = self.get_success(
self.store.db_pool.simple_select_one(
table + "_current",
{id_col: stat_id},
@ -93,6 +93,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
return None if row is None else dict(zip(cols, row))
def _perform_background_initial_update(self) -> None:
# Do the initial population of the stats via the background update
self._add_background_updates()

View file

@ -366,7 +366,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
profile = self.get_success(self.store._get_user_in_directory(regular_user_id))
assert profile is not None
self.assertTrue(profile["display_name"] == display_name)
self.assertTrue(profile[0] == display_name)
def test_handle_local_profile_change_with_deactivated_user(self) -> None:
# create user
@ -385,7 +385,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is in directory
profile = self.get_success(self.store._get_user_in_directory(r_user_id))
assert profile is not None
self.assertTrue(profile["display_name"] == display_name)
self.assertEqual(profile[0], display_name)
# deactivate user
self.get_success(self.store.set_user_deactivated_status(r_user_id, True))

View file

@ -2706,7 +2706,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# is in user directory
profile = self.get_success(self.store._get_user_in_directory(self.other_user))
assert profile is not None
self.assertTrue(profile["display_name"] == "User")
self.assertEqual(profile[0], "User")
# Deactivate user
channel = self.make_request(

View file

@ -139,12 +139,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
#
# Note that we don't have the UI Auth session ID, so just pull out the single
# row.
ui_auth_data = self.get_success(
self.store.db_pool.simple_select_one(
"ui_auth_sessions", keyvalues={}, retcols=("clientdict",)
result = self.get_success(
self.store.db_pool.simple_select_one_onecol(
"ui_auth_sessions", keyvalues={}, retcol="clientdict"
)
)
client_dict = db_to_json(ui_auth_data["clientdict"])
client_dict = db_to_json(result)
self.assertNotIn("new_password", client_dict)
@override_config({"rc_3pid_validation": {"burst_count": 3}})

View file

@ -270,15 +270,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertLessEqual(det_data.items(), channel.json_body.items())
# Check the `completed` counter has been incremented and pending is 0
res = self.get_success(
pending, completed = self.get_success(
store.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["pending", "completed"],
)
)
self.assertEqual(res["completed"], 1)
self.assertEqual(res["pending"], 0)
self.assertEqual(completed, 1)
self.assertEqual(pending, 0)
@override_config({"registration_requires_token": True})
def test_POST_registration_token_invalid(self) -> None:
@ -372,15 +372,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
params1["auth"]["type"] = LoginType.DUMMY
self.make_request(b"POST", self.url, params1)
# Check pending=0 and completed=1
res = self.get_success(
pending, completed = self.get_success(
store.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["pending", "completed"],
)
)
self.assertEqual(res["pending"], 0)
self.assertEqual(res["completed"], 1)
self.assertEqual(pending, 0)
self.assertEqual(completed, 1)
# Check auth still fails when using token with session2
channel = self.make_request(b"POST", self.url, params2)

View file

@ -222,7 +222,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
)
self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret)
self.assertEqual((1, 2, 3), ret)
self.mock_txn.execute.assert_called_once_with(
"SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"]
)
@ -243,7 +243,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
)
self.assertFalse(ret)
self.assertIsNone(ret)
@defer.inlineCallbacks
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:

View file

@ -42,16 +42,9 @@ class RoomStoreTestCase(HomeserverTestCase):
)
def test_get_room(self) -> None:
res = self.get_success(self.store.get_room(self.room.to_string()))
assert res is not None
self.assertLessEqual(
{
"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
"is_public": True,
}.items(),
res.items(),
)
room = self.get_success(self.store.get_room(self.room.to_string()))
assert room is not None
self.assertTrue(room[0])
def test_get_room_unknown_room(self) -> None:
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))