Refactor get_user_by_id (#16316)

This commit is contained in:
Erik Johnston 2023-09-14 12:46:30 +01:00 committed by GitHub
parent 032cf84f52
commit 954921736b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 108 additions and 123 deletions

1
changelog.d/16316.misc Normal file
View file

@ -0,0 +1 @@
Refactor `get_user_by_id`.

View file

@ -268,7 +268,7 @@ class InternalAuth(BaseAuth):
stored_user = await self.store.get_user_by_id(user_id)
if not stored_user:
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
if not stored_user["is_guest"]:
if not stored_user.is_guest:
raise InvalidClientTokenError(
"Guest access token used for regular user"
)

View file

@ -300,7 +300,7 @@ class MSC3861DelegatedAuth(BaseAuth):
user_id = UserID(username, self._hostname)
# First try to find a user from the username claim
user_info = await self.store.get_userinfo_by_id(user_id=user_id.to_string())
user_info = await self.store.get_user_by_id(user_id=user_id.to_string())
if user_info is None:
# If the user does not exist, we should create it on the fly
# TODO: we could use SCIM to provision users ahead of time and listen

View file

@ -102,7 +102,7 @@ class AccountHandler:
"""
status = {"exists": False}
userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string())
userinfo = await self._main_store.get_user_by_id(user_id.to_string())
if userinfo is not None:
status = {

View file

@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
@ -57,38 +57,30 @@ class AdminHandler:
async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details"""
user_info_dict = await self._store.get_user_by_id(user.to_string())
if user_info_dict is None:
user_info: Optional[UserInfo] = await self._store.get_user_by_id(
user.to_string()
)
if user_info is None:
return None
# Restrict returned information to a known set of fields. This prevents additional
# fields added to get_user_by_id from modifying Synapse's external API surface.
user_info_to_return = {
"name",
"admin",
"deactivated",
"locked",
"shadow_banned",
"creation_ts",
"appservice_id",
"consent_server_notice_sent",
"consent_version",
"consent_ts",
"user_type",
"is_guest",
"last_seen_ts",
user_info_dict = {
"name": user.to_string(),
"admin": user_info.is_admin,
"deactivated": user_info.is_deactivated,
"locked": user_info.locked,
"shadow_banned": user_info.is_shadow_banned,
"creation_ts": user_info.creation_ts,
"appservice_id": user_info.appservice_id,
"consent_server_notice_sent": user_info.consent_server_notice_sent,
"consent_version": user_info.consent_version,
"consent_ts": user_info.consent_ts,
"user_type": user_info.user_type,
"is_guest": user_info.is_guest,
}
if self._msc3866_enabled:
# Only include the approved flag if support for MSC3866 is enabled.
user_info_to_return.add("approved")
# Restrict returned keys to a known set.
user_info_dict = {
key: value
for key, value in user_info_dict.items()
if key in user_info_to_return
}
user_info_dict["approved"] = user_info.approved
# Add additional user metadata
profile = await self._store.get_profileinfo(user)
@ -105,6 +97,9 @@ class AdminHandler:
user_info_dict["external_ids"] = external_ids
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string())
user_info_dict["last_seen_ts"] = last_seen_ts
return user_info_dict
async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:

View file

@ -828,13 +828,13 @@ class EventCreationHandler:
u = await self.store.get_user_by_id(user_id)
assert u is not None
if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
if u.user_type in (UserTypes.SUPPORT, UserTypes.BOT):
# support and bot users are not required to consent
return
if u["appservice_id"] is not None:
if u.appservice_id is not None:
# users registered by an appservice are exempt
return
if u["consent_version"] == self.config.consent.user_consent_version:
if u.consent_version == self.config.consent.user_consent_version:
return
consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)

View file

@ -572,7 +572,7 @@ class ModuleApi:
Returns:
UserInfo object if a user was found, otherwise None
"""
return await self._store.get_userinfo_by_id(user_id)
return await self._store.get_user_by_id(user_id)
async def get_user_by_req(
self,
@ -1878,7 +1878,7 @@ class AccountDataManager:
raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}")
# Ensure the user exists, so we don't just write to users that aren't there.
if await self._store.get_userinfo_by_id(user_id) is None:
if await self._store.get_user_by_id(user_id) is None:
raise ValueError(f"User {user_id} does not exist on this server.")
await self._handler.add_account_data_for_user(user_id, data_type, new_data)

View file

@ -129,7 +129,7 @@ class ConsentResource(DirectServeHtmlResource):
if u is None:
raise NotFoundError("Unknown user")
has_consented = u["consent_version"] == version
has_consented = u.consent_version == version
userhmac = userhmac_bytes.decode("ascii")
try:

View file

@ -79,15 +79,15 @@ class ConsentServerNotices:
if u is None:
return
if u["is_guest"] and not self._send_to_guests:
if u.is_guest and not self._send_to_guests:
# don't send to guests
return
if u["consent_version"] == self._current_consent_version:
if u.consent_version == self._current_consent_version:
# user has already consented
return
if u["consent_server_notice_sent"] == self._current_consent_version:
if u.consent_server_notice_sent == self._current_consent_version:
# we've already sent a notice to the user
return

View file

@ -764,3 +764,14 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
}
return list(results.values())
async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]:
"""Get the last seen timestamp for a user, if we have it."""
return await self.db_pool.simple_select_one_onecol(
table="user_ips",
keyvalues={"user_id": user_id},
retcol="MAX(last_seen)",
allow_none=True,
desc="get_last_seen_for_user_id",
)

View file

@ -16,7 +16,7 @@
import logging
import random
import re
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import attr
@ -192,8 +192,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
"""Deprecated: use get_userinfo_by_id instead"""
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Returns info about the user account, if it exists."""
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
# We could technically use simple_select_one here, but it would not perform
@ -202,16 +202,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
txn.execute(
"""
SELECT
name, password_hash, is_guest, admin, consent_version, consent_ts,
name, is_guest, admin, consent_version, consent_ts,
consent_server_notice_sent, appservice_id, creation_ts, user_type,
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
COALESCE(approved, TRUE) AS approved,
COALESCE(locked, FALSE) AS locked, last_seen_ts
COALESCE(locked, FALSE) AS locked
FROM users
LEFT JOIN (
SELECT user_id, MAX(last_seen) AS last_seen_ts
FROM user_ips GROUP BY user_id
) ls ON users.name = ls.user_id
WHERE name = ?
""",
(user_id,),
@ -228,51 +224,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_user_by_id",
func=get_user_by_id_txn,
)
if row is not None:
# If we're using SQLite our boolean values will be integers. Because we
# present some of this data as is to e.g. server admins via REST APIs, we
# want to make sure we're returning the right type of data.
# Note: when adding a column name to this list, be wary of NULLable columns,
# since NULL values will be turned into False.
boolean_columns = [
"admin",
"deactivated",
"shadow_banned",
"approved",
"locked",
]
for column in boolean_columns:
row[column] = bool(row[column])
return row
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Get a UserInfo object for a user by user ID.
Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
this method should be cached.
Args:
user_id: The user to fetch user info for.
Returns:
`UserInfo` object if user found, otherwise `None`.
"""
user_data = await self.get_user_by_id(user_id)
if not user_data:
if row is None:
return None
return UserInfo(
appservice_id=user_data["appservice_id"],
consent_server_notice_sent=user_data["consent_server_notice_sent"],
consent_version=user_data["consent_version"],
creation_ts=user_data["creation_ts"],
is_admin=bool(user_data["admin"]),
is_deactivated=bool(user_data["deactivated"]),
is_guest=bool(user_data["is_guest"]),
is_shadow_banned=bool(user_data["shadow_banned"]),
user_id=UserID.from_string(user_data["name"]),
user_type=user_data["user_type"],
last_seen_ts=user_data["last_seen_ts"],
appservice_id=row["appservice_id"],
consent_server_notice_sent=row["consent_server_notice_sent"],
consent_version=row["consent_version"],
consent_ts=row["consent_ts"],
creation_ts=row["creation_ts"],
is_admin=bool(row["admin"]),
is_deactivated=bool(row["deactivated"]),
is_guest=bool(row["is_guest"]),
is_shadow_banned=bool(row["shadow_banned"]),
user_id=UserID.from_string(row["name"]),
user_type=row["user_type"],
approved=bool(row["approved"]),
locked=bool(row["locked"]),
)
async def is_trial_user(self, user_id: str) -> bool:
@ -290,10 +258,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
now = self._clock.time_msec()
days = self.config.server.mau_appservice_trial_days.get(
info["appservice_id"], self.config.server.mau_trial_days
info.appservice_id, self.config.server.mau_trial_days
)
trial_duration_ms = days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
is_trial = (now - info.creation_ts * 1000) < trial_duration_ms
return is_trial
@cached()

View file

@ -933,33 +933,37 @@ def get_verify_key_from_cross_signing_key(
@attr.s(auto_attribs=True, frozen=True, slots=True)
class UserInfo:
"""Holds information about a user. Result of get_userinfo_by_id.
"""Holds information about a user. Result of get_user_by_id.
Attributes:
user_id: ID of the user.
appservice_id: Application service ID that created this user.
consent_server_notice_sent: Version of policy documents the user has been sent.
consent_version: Version of policy documents the user has consented to.
consent_ts: Time the user consented
creation_ts: Creation timestamp of the user.
is_admin: True if the user is an admin.
is_deactivated: True if the user has been deactivated.
is_guest: True if the user is a guest user.
is_shadow_banned: True if the user has been shadow-banned.
user_type: User type (None for normal user, 'support' and 'bot' other options).
last_seen_ts: Last activity timestamp of the user.
approved: If the user has been "approved" to register on the server.
locked: Whether the user's account has been locked
"""
user_id: UserID
appservice_id: Optional[int]
consent_server_notice_sent: Optional[str]
consent_version: Optional[str]
consent_ts: Optional[int]
user_type: Optional[str]
creation_ts: int
is_admin: bool
is_deactivated: bool
is_guest: bool
is_shadow_banned: bool
last_seen_ts: Optional[int]
approved: bool
locked: bool
class UserProfile(TypedDict):

View file

@ -188,8 +188,11 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value.
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
class FakeUserInfo:
is_guest = False
self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
@ -341,7 +344,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
def test_get_guest_user_from_macaroon(self) -> None:
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
class FakeUserInfo:
is_guest = True
self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
self.store.get_user_by_access_token = AsyncMock(return_value=None)
user_id = "@baldrick:matrix.org"

View file

@ -16,7 +16,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID
from synapse.types import JsonDict, UserID, UserInfo
from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config
@ -35,24 +35,22 @@ class RegistrationStoreTestCase(HomeserverTestCase):
self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.assertEqual(
{
UserInfo(
# TODO(paul): Surely this field should be 'user_id', not 'name'
"name": self.user_id,
"password_hash": self.pwhash,
"admin": 0,
"is_guest": 0,
"consent_version": None,
"consent_ts": None,
"consent_server_notice_sent": None,
"appservice_id": None,
"creation_ts": 0,
"user_type": None,
"deactivated": 0,
"locked": 0,
"shadow_banned": 0,
"approved": 1,
"last_seen_ts": None,
},
user_id=UserID.from_string(self.user_id),
is_admin=False,
is_guest=False,
consent_server_notice_sent=None,
consent_ts=None,
consent_version=None,
appservice_id=None,
creation_ts=0,
user_type=None,
is_deactivated=False,
locked=False,
is_shadow_banned=False,
approved=True,
),
(self.get_success(self.store.get_user_by_id(self.user_id))),
)
@ -65,9 +63,11 @@ class RegistrationStoreTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user
self.assertEqual(user["consent_version"], "1")
self.assertGreater(user["consent_ts"], before_consent)
self.assertLess(user["consent_ts"], self.clock.time_msec())
self.assertEqual(user.consent_version, "1")
self.assertIsNotNone(user.consent_ts)
assert user.consent_ts is not None
self.assertGreater(user.consent_ts, before_consent)
self.assertLess(user.consent_ts, self.clock.time_msec())
def test_add_tokens(self) -> None:
self.get_success(self.store.register_user(self.user_id, self.pwhash))
@ -215,7 +215,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user is not None
self.assertTrue(user["approved"])
self.assertTrue(user.approved)
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved)
@ -228,7 +228,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user is not None
self.assertFalse(user["approved"])
self.assertFalse(user.approved)
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertFalse(approved)
@ -248,7 +248,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
self.assertIsNotNone(user)
assert user is not None
self.assertEqual(user["approved"], 1)
self.assertEqual(user.approved, 1)
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved)