mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-22 01:25:44 +03:00
Refactor get_user_by_id
(#16316)
This commit is contained in:
parent
032cf84f52
commit
954921736b
14 changed files with 108 additions and 123 deletions
1
changelog.d/16316.misc
Normal file
1
changelog.d/16316.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Refactor `get_user_by_id`.
|
|
@ -268,7 +268,7 @@ class InternalAuth(BaseAuth):
|
|||
stored_user = await self.store.get_user_by_id(user_id)
|
||||
if not stored_user:
|
||||
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
|
||||
if not stored_user["is_guest"]:
|
||||
if not stored_user.is_guest:
|
||||
raise InvalidClientTokenError(
|
||||
"Guest access token used for regular user"
|
||||
)
|
||||
|
|
|
@ -300,7 +300,7 @@ class MSC3861DelegatedAuth(BaseAuth):
|
|||
user_id = UserID(username, self._hostname)
|
||||
|
||||
# First try to find a user from the username claim
|
||||
user_info = await self.store.get_userinfo_by_id(user_id=user_id.to_string())
|
||||
user_info = await self.store.get_user_by_id(user_id=user_id.to_string())
|
||||
if user_info is None:
|
||||
# If the user does not exist, we should create it on the fly
|
||||
# TODO: we could use SCIM to provision users ahead of time and listen
|
||||
|
|
|
@ -102,7 +102,7 @@ class AccountHandler:
|
|||
"""
|
||||
status = {"exists": False}
|
||||
|
||||
userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string())
|
||||
userinfo = await self._main_store.get_user_by_id(user_id.to_string())
|
||||
|
||||
if userinfo is not None:
|
||||
status = {
|
||||
|
|
|
@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
|
|||
|
||||
from synapse.api.constants import Direction, Membership
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
|
||||
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -57,38 +57,30 @@ class AdminHandler:
|
|||
|
||||
async def get_user(self, user: UserID) -> Optional[JsonDict]:
|
||||
"""Function to get user details"""
|
||||
user_info_dict = await self._store.get_user_by_id(user.to_string())
|
||||
if user_info_dict is None:
|
||||
user_info: Optional[UserInfo] = await self._store.get_user_by_id(
|
||||
user.to_string()
|
||||
)
|
||||
if user_info is None:
|
||||
return None
|
||||
|
||||
# Restrict returned information to a known set of fields. This prevents additional
|
||||
# fields added to get_user_by_id from modifying Synapse's external API surface.
|
||||
user_info_to_return = {
|
||||
"name",
|
||||
"admin",
|
||||
"deactivated",
|
||||
"locked",
|
||||
"shadow_banned",
|
||||
"creation_ts",
|
||||
"appservice_id",
|
||||
"consent_server_notice_sent",
|
||||
"consent_version",
|
||||
"consent_ts",
|
||||
"user_type",
|
||||
"is_guest",
|
||||
"last_seen_ts",
|
||||
user_info_dict = {
|
||||
"name": user.to_string(),
|
||||
"admin": user_info.is_admin,
|
||||
"deactivated": user_info.is_deactivated,
|
||||
"locked": user_info.locked,
|
||||
"shadow_banned": user_info.is_shadow_banned,
|
||||
"creation_ts": user_info.creation_ts,
|
||||
"appservice_id": user_info.appservice_id,
|
||||
"consent_server_notice_sent": user_info.consent_server_notice_sent,
|
||||
"consent_version": user_info.consent_version,
|
||||
"consent_ts": user_info.consent_ts,
|
||||
"user_type": user_info.user_type,
|
||||
"is_guest": user_info.is_guest,
|
||||
}
|
||||
|
||||
if self._msc3866_enabled:
|
||||
# Only include the approved flag if support for MSC3866 is enabled.
|
||||
user_info_to_return.add("approved")
|
||||
|
||||
# Restrict returned keys to a known set.
|
||||
user_info_dict = {
|
||||
key: value
|
||||
for key, value in user_info_dict.items()
|
||||
if key in user_info_to_return
|
||||
}
|
||||
user_info_dict["approved"] = user_info.approved
|
||||
|
||||
# Add additional user metadata
|
||||
profile = await self._store.get_profileinfo(user)
|
||||
|
@ -105,6 +97,9 @@ class AdminHandler:
|
|||
user_info_dict["external_ids"] = external_ids
|
||||
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
|
||||
|
||||
last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string())
|
||||
user_info_dict["last_seen_ts"] = last_seen_ts
|
||||
|
||||
return user_info_dict
|
||||
|
||||
async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
|
||||
|
|
|
@ -828,13 +828,13 @@ class EventCreationHandler:
|
|||
|
||||
u = await self.store.get_user_by_id(user_id)
|
||||
assert u is not None
|
||||
if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
|
||||
if u.user_type in (UserTypes.SUPPORT, UserTypes.BOT):
|
||||
# support and bot users are not required to consent
|
||||
return
|
||||
if u["appservice_id"] is not None:
|
||||
if u.appservice_id is not None:
|
||||
# users registered by an appservice are exempt
|
||||
return
|
||||
if u["consent_version"] == self.config.consent.user_consent_version:
|
||||
if u.consent_version == self.config.consent.user_consent_version:
|
||||
return
|
||||
|
||||
consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
|
||||
|
|
|
@ -572,7 +572,7 @@ class ModuleApi:
|
|||
Returns:
|
||||
UserInfo object if a user was found, otherwise None
|
||||
"""
|
||||
return await self._store.get_userinfo_by_id(user_id)
|
||||
return await self._store.get_user_by_id(user_id)
|
||||
|
||||
async def get_user_by_req(
|
||||
self,
|
||||
|
@ -1878,7 +1878,7 @@ class AccountDataManager:
|
|||
raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}")
|
||||
|
||||
# Ensure the user exists, so we don't just write to users that aren't there.
|
||||
if await self._store.get_userinfo_by_id(user_id) is None:
|
||||
if await self._store.get_user_by_id(user_id) is None:
|
||||
raise ValueError(f"User {user_id} does not exist on this server.")
|
||||
|
||||
await self._handler.add_account_data_for_user(user_id, data_type, new_data)
|
||||
|
|
|
@ -129,7 +129,7 @@ class ConsentResource(DirectServeHtmlResource):
|
|||
if u is None:
|
||||
raise NotFoundError("Unknown user")
|
||||
|
||||
has_consented = u["consent_version"] == version
|
||||
has_consented = u.consent_version == version
|
||||
userhmac = userhmac_bytes.decode("ascii")
|
||||
|
||||
try:
|
||||
|
|
|
@ -79,15 +79,15 @@ class ConsentServerNotices:
|
|||
if u is None:
|
||||
return
|
||||
|
||||
if u["is_guest"] and not self._send_to_guests:
|
||||
if u.is_guest and not self._send_to_guests:
|
||||
# don't send to guests
|
||||
return
|
||||
|
||||
if u["consent_version"] == self._current_consent_version:
|
||||
if u.consent_version == self._current_consent_version:
|
||||
# user has already consented
|
||||
return
|
||||
|
||||
if u["consent_server_notice_sent"] == self._current_consent_version:
|
||||
if u.consent_server_notice_sent == self._current_consent_version:
|
||||
# we've already sent a notice to the user
|
||||
return
|
||||
|
||||
|
|
|
@ -764,3 +764,14 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
|
|||
}
|
||||
|
||||
return list(results.values())
|
||||
|
||||
async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]:
|
||||
"""Get the last seen timestamp for a user, if we have it."""
|
||||
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="user_ips",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="MAX(last_seen)",
|
||||
allow_none=True,
|
||||
desc="get_last_seen_for_user_id",
|
||||
)
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import logging
|
||||
import random
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -192,8 +192,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
)
|
||||
|
||||
@cached()
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
|
||||
"""Deprecated: use get_userinfo_by_id instead"""
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
|
||||
"""Returns info about the user account, if it exists."""
|
||||
|
||||
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
|
||||
# We could technically use simple_select_one here, but it would not perform
|
||||
|
@ -202,16 +202,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
txn.execute(
|
||||
"""
|
||||
SELECT
|
||||
name, password_hash, is_guest, admin, consent_version, consent_ts,
|
||||
name, is_guest, admin, consent_version, consent_ts,
|
||||
consent_server_notice_sent, appservice_id, creation_ts, user_type,
|
||||
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
|
||||
COALESCE(approved, TRUE) AS approved,
|
||||
COALESCE(locked, FALSE) AS locked, last_seen_ts
|
||||
COALESCE(locked, FALSE) AS locked
|
||||
FROM users
|
||||
LEFT JOIN (
|
||||
SELECT user_id, MAX(last_seen) AS last_seen_ts
|
||||
FROM user_ips GROUP BY user_id
|
||||
) ls ON users.name = ls.user_id
|
||||
WHERE name = ?
|
||||
""",
|
||||
(user_id,),
|
||||
|
@ -228,51 +224,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
desc="get_user_by_id",
|
||||
func=get_user_by_id_txn,
|
||||
)
|
||||
|
||||
if row is not None:
|
||||
# If we're using SQLite our boolean values will be integers. Because we
|
||||
# present some of this data as is to e.g. server admins via REST APIs, we
|
||||
# want to make sure we're returning the right type of data.
|
||||
# Note: when adding a column name to this list, be wary of NULLable columns,
|
||||
# since NULL values will be turned into False.
|
||||
boolean_columns = [
|
||||
"admin",
|
||||
"deactivated",
|
||||
"shadow_banned",
|
||||
"approved",
|
||||
"locked",
|
||||
]
|
||||
for column in boolean_columns:
|
||||
row[column] = bool(row[column])
|
||||
|
||||
return row
|
||||
|
||||
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
|
||||
"""Get a UserInfo object for a user by user ID.
|
||||
|
||||
Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
|
||||
this method should be cached.
|
||||
|
||||
Args:
|
||||
user_id: The user to fetch user info for.
|
||||
Returns:
|
||||
`UserInfo` object if user found, otherwise `None`.
|
||||
"""
|
||||
user_data = await self.get_user_by_id(user_id)
|
||||
if not user_data:
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
return UserInfo(
|
||||
appservice_id=user_data["appservice_id"],
|
||||
consent_server_notice_sent=user_data["consent_server_notice_sent"],
|
||||
consent_version=user_data["consent_version"],
|
||||
creation_ts=user_data["creation_ts"],
|
||||
is_admin=bool(user_data["admin"]),
|
||||
is_deactivated=bool(user_data["deactivated"]),
|
||||
is_guest=bool(user_data["is_guest"]),
|
||||
is_shadow_banned=bool(user_data["shadow_banned"]),
|
||||
user_id=UserID.from_string(user_data["name"]),
|
||||
user_type=user_data["user_type"],
|
||||
last_seen_ts=user_data["last_seen_ts"],
|
||||
appservice_id=row["appservice_id"],
|
||||
consent_server_notice_sent=row["consent_server_notice_sent"],
|
||||
consent_version=row["consent_version"],
|
||||
consent_ts=row["consent_ts"],
|
||||
creation_ts=row["creation_ts"],
|
||||
is_admin=bool(row["admin"]),
|
||||
is_deactivated=bool(row["deactivated"]),
|
||||
is_guest=bool(row["is_guest"]),
|
||||
is_shadow_banned=bool(row["shadow_banned"]),
|
||||
user_id=UserID.from_string(row["name"]),
|
||||
user_type=row["user_type"],
|
||||
approved=bool(row["approved"]),
|
||||
locked=bool(row["locked"]),
|
||||
)
|
||||
|
||||
async def is_trial_user(self, user_id: str) -> bool:
|
||||
|
@ -290,10 +258,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
now = self._clock.time_msec()
|
||||
days = self.config.server.mau_appservice_trial_days.get(
|
||||
info["appservice_id"], self.config.server.mau_trial_days
|
||||
info.appservice_id, self.config.server.mau_trial_days
|
||||
)
|
||||
trial_duration_ms = days * 24 * 60 * 60 * 1000
|
||||
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
|
||||
is_trial = (now - info.creation_ts * 1000) < trial_duration_ms
|
||||
return is_trial
|
||||
|
||||
@cached()
|
||||
|
|
|
@ -933,33 +933,37 @@ def get_verify_key_from_cross_signing_key(
|
|||
|
||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
||||
class UserInfo:
|
||||
"""Holds information about a user. Result of get_userinfo_by_id.
|
||||
"""Holds information about a user. Result of get_user_by_id.
|
||||
|
||||
Attributes:
|
||||
user_id: ID of the user.
|
||||
appservice_id: Application service ID that created this user.
|
||||
consent_server_notice_sent: Version of policy documents the user has been sent.
|
||||
consent_version: Version of policy documents the user has consented to.
|
||||
consent_ts: Time the user consented
|
||||
creation_ts: Creation timestamp of the user.
|
||||
is_admin: True if the user is an admin.
|
||||
is_deactivated: True if the user has been deactivated.
|
||||
is_guest: True if the user is a guest user.
|
||||
is_shadow_banned: True if the user has been shadow-banned.
|
||||
user_type: User type (None for normal user, 'support' and 'bot' other options).
|
||||
last_seen_ts: Last activity timestamp of the user.
|
||||
approved: If the user has been "approved" to register on the server.
|
||||
locked: Whether the user's account has been locked
|
||||
"""
|
||||
|
||||
user_id: UserID
|
||||
appservice_id: Optional[int]
|
||||
consent_server_notice_sent: Optional[str]
|
||||
consent_version: Optional[str]
|
||||
consent_ts: Optional[int]
|
||||
user_type: Optional[str]
|
||||
creation_ts: int
|
||||
is_admin: bool
|
||||
is_deactivated: bool
|
||||
is_guest: bool
|
||||
is_shadow_banned: bool
|
||||
last_seen_ts: Optional[int]
|
||||
approved: bool
|
||||
locked: bool
|
||||
|
||||
|
||||
class UserProfile(TypedDict):
|
||||
|
|
|
@ -188,8 +188,11 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
app_service.is_interested_in_user = Mock(return_value=True)
|
||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||
# This just needs to return a truth-y value.
|
||||
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
|
||||
|
||||
class FakeUserInfo:
|
||||
is_guest = False
|
||||
|
||||
self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
|
||||
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||
|
||||
request = Mock(args={})
|
||||
|
@ -341,7 +344,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
def test_get_guest_user_from_macaroon(self) -> None:
|
||||
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
|
||||
class FakeUserInfo:
|
||||
is_guest = True
|
||||
|
||||
self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
|
||||
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||
|
||||
user_id = "@baldrick:matrix.org"
|
||||
|
|
|
@ -16,7 +16,7 @@ from twisted.test.proto_helpers import MemoryReactor
|
|||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import ThreepidValidationError
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.types import JsonDict, UserID, UserInfo
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.unittest import HomeserverTestCase, override_config
|
||||
|
@ -35,24 +35,22 @@ class RegistrationStoreTestCase(HomeserverTestCase):
|
|||
self.get_success(self.store.register_user(self.user_id, self.pwhash))
|
||||
|
||||
self.assertEqual(
|
||||
{
|
||||
UserInfo(
|
||||
# TODO(paul): Surely this field should be 'user_id', not 'name'
|
||||
"name": self.user_id,
|
||||
"password_hash": self.pwhash,
|
||||
"admin": 0,
|
||||
"is_guest": 0,
|
||||
"consent_version": None,
|
||||
"consent_ts": None,
|
||||
"consent_server_notice_sent": None,
|
||||
"appservice_id": None,
|
||||
"creation_ts": 0,
|
||||
"user_type": None,
|
||||
"deactivated": 0,
|
||||
"locked": 0,
|
||||
"shadow_banned": 0,
|
||||
"approved": 1,
|
||||
"last_seen_ts": None,
|
||||
},
|
||||
user_id=UserID.from_string(self.user_id),
|
||||
is_admin=False,
|
||||
is_guest=False,
|
||||
consent_server_notice_sent=None,
|
||||
consent_ts=None,
|
||||
consent_version=None,
|
||||
appservice_id=None,
|
||||
creation_ts=0,
|
||||
user_type=None,
|
||||
is_deactivated=False,
|
||||
locked=False,
|
||||
is_shadow_banned=False,
|
||||
approved=True,
|
||||
),
|
||||
(self.get_success(self.store.get_user_by_id(self.user_id))),
|
||||
)
|
||||
|
||||
|
@ -65,9 +63,11 @@ class RegistrationStoreTestCase(HomeserverTestCase):
|
|||
|
||||
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
||||
assert user
|
||||
self.assertEqual(user["consent_version"], "1")
|
||||
self.assertGreater(user["consent_ts"], before_consent)
|
||||
self.assertLess(user["consent_ts"], self.clock.time_msec())
|
||||
self.assertEqual(user.consent_version, "1")
|
||||
self.assertIsNotNone(user.consent_ts)
|
||||
assert user.consent_ts is not None
|
||||
self.assertGreater(user.consent_ts, before_consent)
|
||||
self.assertLess(user.consent_ts, self.clock.time_msec())
|
||||
|
||||
def test_add_tokens(self) -> None:
|
||||
self.get_success(self.store.register_user(self.user_id, self.pwhash))
|
||||
|
@ -215,7 +215,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
|
|||
|
||||
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
||||
assert user is not None
|
||||
self.assertTrue(user["approved"])
|
||||
self.assertTrue(user.approved)
|
||||
|
||||
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
||||
self.assertTrue(approved)
|
||||
|
@ -228,7 +228,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
|
|||
|
||||
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
||||
assert user is not None
|
||||
self.assertFalse(user["approved"])
|
||||
self.assertFalse(user.approved)
|
||||
|
||||
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
||||
self.assertFalse(approved)
|
||||
|
@ -248,7 +248,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
|
|||
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
||||
self.assertIsNotNone(user)
|
||||
assert user is not None
|
||||
self.assertEqual(user["approved"], 1)
|
||||
self.assertEqual(user.approved, 1)
|
||||
|
||||
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
||||
self.assertTrue(approved)
|
||||
|
|
Loading…
Reference in a new issue