mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-25 11:05:49 +03:00
Refactor get_user_by_id
(#16316)
This commit is contained in:
parent
032cf84f52
commit
954921736b
14 changed files with 108 additions and 123 deletions
1
changelog.d/16316.misc
Normal file
1
changelog.d/16316.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Refactor `get_user_by_id`.
|
|
@ -268,7 +268,7 @@ class InternalAuth(BaseAuth):
|
||||||
stored_user = await self.store.get_user_by_id(user_id)
|
stored_user = await self.store.get_user_by_id(user_id)
|
||||||
if not stored_user:
|
if not stored_user:
|
||||||
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
|
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
|
||||||
if not stored_user["is_guest"]:
|
if not stored_user.is_guest:
|
||||||
raise InvalidClientTokenError(
|
raise InvalidClientTokenError(
|
||||||
"Guest access token used for regular user"
|
"Guest access token used for regular user"
|
||||||
)
|
)
|
||||||
|
|
|
@ -300,7 +300,7 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||||
user_id = UserID(username, self._hostname)
|
user_id = UserID(username, self._hostname)
|
||||||
|
|
||||||
# First try to find a user from the username claim
|
# First try to find a user from the username claim
|
||||||
user_info = await self.store.get_userinfo_by_id(user_id=user_id.to_string())
|
user_info = await self.store.get_user_by_id(user_id=user_id.to_string())
|
||||||
if user_info is None:
|
if user_info is None:
|
||||||
# If the user does not exist, we should create it on the fly
|
# If the user does not exist, we should create it on the fly
|
||||||
# TODO: we could use SCIM to provision users ahead of time and listen
|
# TODO: we could use SCIM to provision users ahead of time and listen
|
||||||
|
|
|
@ -102,7 +102,7 @@ class AccountHandler:
|
||||||
"""
|
"""
|
||||||
status = {"exists": False}
|
status = {"exists": False}
|
||||||
|
|
||||||
userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string())
|
userinfo = await self._main_store.get_user_by_id(user_id.to_string())
|
||||||
|
|
||||||
if userinfo is not None:
|
if userinfo is not None:
|
||||||
status = {
|
status = {
|
||||||
|
|
|
@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
|
||||||
|
|
||||||
from synapse.api.constants import Direction, Membership
|
from synapse.api.constants import Direction, Membership
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
|
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -57,38 +57,30 @@ class AdminHandler:
|
||||||
|
|
||||||
async def get_user(self, user: UserID) -> Optional[JsonDict]:
|
async def get_user(self, user: UserID) -> Optional[JsonDict]:
|
||||||
"""Function to get user details"""
|
"""Function to get user details"""
|
||||||
user_info_dict = await self._store.get_user_by_id(user.to_string())
|
user_info: Optional[UserInfo] = await self._store.get_user_by_id(
|
||||||
if user_info_dict is None:
|
user.to_string()
|
||||||
|
)
|
||||||
|
if user_info is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Restrict returned information to a known set of fields. This prevents additional
|
user_info_dict = {
|
||||||
# fields added to get_user_by_id from modifying Synapse's external API surface.
|
"name": user.to_string(),
|
||||||
user_info_to_return = {
|
"admin": user_info.is_admin,
|
||||||
"name",
|
"deactivated": user_info.is_deactivated,
|
||||||
"admin",
|
"locked": user_info.locked,
|
||||||
"deactivated",
|
"shadow_banned": user_info.is_shadow_banned,
|
||||||
"locked",
|
"creation_ts": user_info.creation_ts,
|
||||||
"shadow_banned",
|
"appservice_id": user_info.appservice_id,
|
||||||
"creation_ts",
|
"consent_server_notice_sent": user_info.consent_server_notice_sent,
|
||||||
"appservice_id",
|
"consent_version": user_info.consent_version,
|
||||||
"consent_server_notice_sent",
|
"consent_ts": user_info.consent_ts,
|
||||||
"consent_version",
|
"user_type": user_info.user_type,
|
||||||
"consent_ts",
|
"is_guest": user_info.is_guest,
|
||||||
"user_type",
|
|
||||||
"is_guest",
|
|
||||||
"last_seen_ts",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if self._msc3866_enabled:
|
if self._msc3866_enabled:
|
||||||
# Only include the approved flag if support for MSC3866 is enabled.
|
# Only include the approved flag if support for MSC3866 is enabled.
|
||||||
user_info_to_return.add("approved")
|
user_info_dict["approved"] = user_info.approved
|
||||||
|
|
||||||
# Restrict returned keys to a known set.
|
|
||||||
user_info_dict = {
|
|
||||||
key: value
|
|
||||||
for key, value in user_info_dict.items()
|
|
||||||
if key in user_info_to_return
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add additional user metadata
|
# Add additional user metadata
|
||||||
profile = await self._store.get_profileinfo(user)
|
profile = await self._store.get_profileinfo(user)
|
||||||
|
@ -105,6 +97,9 @@ class AdminHandler:
|
||||||
user_info_dict["external_ids"] = external_ids
|
user_info_dict["external_ids"] = external_ids
|
||||||
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
|
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
|
||||||
|
|
||||||
|
last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string())
|
||||||
|
user_info_dict["last_seen_ts"] = last_seen_ts
|
||||||
|
|
||||||
return user_info_dict
|
return user_info_dict
|
||||||
|
|
||||||
async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
|
async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
|
||||||
|
|
|
@ -828,13 +828,13 @@ class EventCreationHandler:
|
||||||
|
|
||||||
u = await self.store.get_user_by_id(user_id)
|
u = await self.store.get_user_by_id(user_id)
|
||||||
assert u is not None
|
assert u is not None
|
||||||
if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
|
if u.user_type in (UserTypes.SUPPORT, UserTypes.BOT):
|
||||||
# support and bot users are not required to consent
|
# support and bot users are not required to consent
|
||||||
return
|
return
|
||||||
if u["appservice_id"] is not None:
|
if u.appservice_id is not None:
|
||||||
# users registered by an appservice are exempt
|
# users registered by an appservice are exempt
|
||||||
return
|
return
|
||||||
if u["consent_version"] == self.config.consent.user_consent_version:
|
if u.consent_version == self.config.consent.user_consent_version:
|
||||||
return
|
return
|
||||||
|
|
||||||
consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
|
consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
|
||||||
|
|
|
@ -572,7 +572,7 @@ class ModuleApi:
|
||||||
Returns:
|
Returns:
|
||||||
UserInfo object if a user was found, otherwise None
|
UserInfo object if a user was found, otherwise None
|
||||||
"""
|
"""
|
||||||
return await self._store.get_userinfo_by_id(user_id)
|
return await self._store.get_user_by_id(user_id)
|
||||||
|
|
||||||
async def get_user_by_req(
|
async def get_user_by_req(
|
||||||
self,
|
self,
|
||||||
|
@ -1878,7 +1878,7 @@ class AccountDataManager:
|
||||||
raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}")
|
raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}")
|
||||||
|
|
||||||
# Ensure the user exists, so we don't just write to users that aren't there.
|
# Ensure the user exists, so we don't just write to users that aren't there.
|
||||||
if await self._store.get_userinfo_by_id(user_id) is None:
|
if await self._store.get_user_by_id(user_id) is None:
|
||||||
raise ValueError(f"User {user_id} does not exist on this server.")
|
raise ValueError(f"User {user_id} does not exist on this server.")
|
||||||
|
|
||||||
await self._handler.add_account_data_for_user(user_id, data_type, new_data)
|
await self._handler.add_account_data_for_user(user_id, data_type, new_data)
|
||||||
|
|
|
@ -129,7 +129,7 @@ class ConsentResource(DirectServeHtmlResource):
|
||||||
if u is None:
|
if u is None:
|
||||||
raise NotFoundError("Unknown user")
|
raise NotFoundError("Unknown user")
|
||||||
|
|
||||||
has_consented = u["consent_version"] == version
|
has_consented = u.consent_version == version
|
||||||
userhmac = userhmac_bytes.decode("ascii")
|
userhmac = userhmac_bytes.decode("ascii")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -79,15 +79,15 @@ class ConsentServerNotices:
|
||||||
if u is None:
|
if u is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if u["is_guest"] and not self._send_to_guests:
|
if u.is_guest and not self._send_to_guests:
|
||||||
# don't send to guests
|
# don't send to guests
|
||||||
return
|
return
|
||||||
|
|
||||||
if u["consent_version"] == self._current_consent_version:
|
if u.consent_version == self._current_consent_version:
|
||||||
# user has already consented
|
# user has already consented
|
||||||
return
|
return
|
||||||
|
|
||||||
if u["consent_server_notice_sent"] == self._current_consent_version:
|
if u.consent_server_notice_sent == self._current_consent_version:
|
||||||
# we've already sent a notice to the user
|
# we've already sent a notice to the user
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -764,3 +764,14 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
|
||||||
}
|
}
|
||||||
|
|
||||||
return list(results.values())
|
return list(results.values())
|
||||||
|
|
||||||
|
async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]:
|
||||||
|
"""Get the last seen timestamp for a user, if we have it."""
|
||||||
|
|
||||||
|
return await self.db_pool.simple_select_one_onecol(
|
||||||
|
table="user_ips",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
retcol="MAX(last_seen)",
|
||||||
|
allow_none=True,
|
||||||
|
desc="get_last_seen_for_user_id",
|
||||||
|
)
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -192,8 +192,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
|
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
|
||||||
"""Deprecated: use get_userinfo_by_id instead"""
|
"""Returns info about the user account, if it exists."""
|
||||||
|
|
||||||
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
|
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
|
||||||
# We could technically use simple_select_one here, but it would not perform
|
# We could technically use simple_select_one here, but it would not perform
|
||||||
|
@ -202,16 +202,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
SELECT
|
SELECT
|
||||||
name, password_hash, is_guest, admin, consent_version, consent_ts,
|
name, is_guest, admin, consent_version, consent_ts,
|
||||||
consent_server_notice_sent, appservice_id, creation_ts, user_type,
|
consent_server_notice_sent, appservice_id, creation_ts, user_type,
|
||||||
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
|
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
|
||||||
COALESCE(approved, TRUE) AS approved,
|
COALESCE(approved, TRUE) AS approved,
|
||||||
COALESCE(locked, FALSE) AS locked, last_seen_ts
|
COALESCE(locked, FALSE) AS locked
|
||||||
FROM users
|
FROM users
|
||||||
LEFT JOIN (
|
|
||||||
SELECT user_id, MAX(last_seen) AS last_seen_ts
|
|
||||||
FROM user_ips GROUP BY user_id
|
|
||||||
) ls ON users.name = ls.user_id
|
|
||||||
WHERE name = ?
|
WHERE name = ?
|
||||||
""",
|
""",
|
||||||
(user_id,),
|
(user_id,),
|
||||||
|
@ -228,51 +224,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
desc="get_user_by_id",
|
desc="get_user_by_id",
|
||||||
func=get_user_by_id_txn,
|
func=get_user_by_id_txn,
|
||||||
)
|
)
|
||||||
|
if row is None:
|
||||||
if row is not None:
|
|
||||||
# If we're using SQLite our boolean values will be integers. Because we
|
|
||||||
# present some of this data as is to e.g. server admins via REST APIs, we
|
|
||||||
# want to make sure we're returning the right type of data.
|
|
||||||
# Note: when adding a column name to this list, be wary of NULLable columns,
|
|
||||||
# since NULL values will be turned into False.
|
|
||||||
boolean_columns = [
|
|
||||||
"admin",
|
|
||||||
"deactivated",
|
|
||||||
"shadow_banned",
|
|
||||||
"approved",
|
|
||||||
"locked",
|
|
||||||
]
|
|
||||||
for column in boolean_columns:
|
|
||||||
row[column] = bool(row[column])
|
|
||||||
|
|
||||||
return row
|
|
||||||
|
|
||||||
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
|
|
||||||
"""Get a UserInfo object for a user by user ID.
|
|
||||||
|
|
||||||
Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
|
|
||||||
this method should be cached.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: The user to fetch user info for.
|
|
||||||
Returns:
|
|
||||||
`UserInfo` object if user found, otherwise `None`.
|
|
||||||
"""
|
|
||||||
user_data = await self.get_user_by_id(user_id)
|
|
||||||
if not user_data:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return UserInfo(
|
return UserInfo(
|
||||||
appservice_id=user_data["appservice_id"],
|
appservice_id=row["appservice_id"],
|
||||||
consent_server_notice_sent=user_data["consent_server_notice_sent"],
|
consent_server_notice_sent=row["consent_server_notice_sent"],
|
||||||
consent_version=user_data["consent_version"],
|
consent_version=row["consent_version"],
|
||||||
creation_ts=user_data["creation_ts"],
|
consent_ts=row["consent_ts"],
|
||||||
is_admin=bool(user_data["admin"]),
|
creation_ts=row["creation_ts"],
|
||||||
is_deactivated=bool(user_data["deactivated"]),
|
is_admin=bool(row["admin"]),
|
||||||
is_guest=bool(user_data["is_guest"]),
|
is_deactivated=bool(row["deactivated"]),
|
||||||
is_shadow_banned=bool(user_data["shadow_banned"]),
|
is_guest=bool(row["is_guest"]),
|
||||||
user_id=UserID.from_string(user_data["name"]),
|
is_shadow_banned=bool(row["shadow_banned"]),
|
||||||
user_type=user_data["user_type"],
|
user_id=UserID.from_string(row["name"]),
|
||||||
last_seen_ts=user_data["last_seen_ts"],
|
user_type=row["user_type"],
|
||||||
|
approved=bool(row["approved"]),
|
||||||
|
locked=bool(row["locked"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def is_trial_user(self, user_id: str) -> bool:
|
async def is_trial_user(self, user_id: str) -> bool:
|
||||||
|
@ -290,10 +258,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
|
|
||||||
now = self._clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
days = self.config.server.mau_appservice_trial_days.get(
|
days = self.config.server.mau_appservice_trial_days.get(
|
||||||
info["appservice_id"], self.config.server.mau_trial_days
|
info.appservice_id, self.config.server.mau_trial_days
|
||||||
)
|
)
|
||||||
trial_duration_ms = days * 24 * 60 * 60 * 1000
|
trial_duration_ms = days * 24 * 60 * 60 * 1000
|
||||||
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
|
is_trial = (now - info.creation_ts * 1000) < trial_duration_ms
|
||||||
return is_trial
|
return is_trial
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
|
|
|
@ -933,33 +933,37 @@ def get_verify_key_from_cross_signing_key(
|
||||||
|
|
||||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
||||||
class UserInfo:
|
class UserInfo:
|
||||||
"""Holds information about a user. Result of get_userinfo_by_id.
|
"""Holds information about a user. Result of get_user_by_id.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
user_id: ID of the user.
|
user_id: ID of the user.
|
||||||
appservice_id: Application service ID that created this user.
|
appservice_id: Application service ID that created this user.
|
||||||
consent_server_notice_sent: Version of policy documents the user has been sent.
|
consent_server_notice_sent: Version of policy documents the user has been sent.
|
||||||
consent_version: Version of policy documents the user has consented to.
|
consent_version: Version of policy documents the user has consented to.
|
||||||
|
consent_ts: Time the user consented
|
||||||
creation_ts: Creation timestamp of the user.
|
creation_ts: Creation timestamp of the user.
|
||||||
is_admin: True if the user is an admin.
|
is_admin: True if the user is an admin.
|
||||||
is_deactivated: True if the user has been deactivated.
|
is_deactivated: True if the user has been deactivated.
|
||||||
is_guest: True if the user is a guest user.
|
is_guest: True if the user is a guest user.
|
||||||
is_shadow_banned: True if the user has been shadow-banned.
|
is_shadow_banned: True if the user has been shadow-banned.
|
||||||
user_type: User type (None for normal user, 'support' and 'bot' other options).
|
user_type: User type (None for normal user, 'support' and 'bot' other options).
|
||||||
last_seen_ts: Last activity timestamp of the user.
|
approved: If the user has been "approved" to register on the server.
|
||||||
|
locked: Whether the user's account has been locked
|
||||||
"""
|
"""
|
||||||
|
|
||||||
user_id: UserID
|
user_id: UserID
|
||||||
appservice_id: Optional[int]
|
appservice_id: Optional[int]
|
||||||
consent_server_notice_sent: Optional[str]
|
consent_server_notice_sent: Optional[str]
|
||||||
consent_version: Optional[str]
|
consent_version: Optional[str]
|
||||||
|
consent_ts: Optional[int]
|
||||||
user_type: Optional[str]
|
user_type: Optional[str]
|
||||||
creation_ts: int
|
creation_ts: int
|
||||||
is_admin: bool
|
is_admin: bool
|
||||||
is_deactivated: bool
|
is_deactivated: bool
|
||||||
is_guest: bool
|
is_guest: bool
|
||||||
is_shadow_banned: bool
|
is_shadow_banned: bool
|
||||||
last_seen_ts: Optional[int]
|
approved: bool
|
||||||
|
locked: bool
|
||||||
|
|
||||||
|
|
||||||
class UserProfile(TypedDict):
|
class UserProfile(TypedDict):
|
||||||
|
|
|
@ -188,8 +188,11 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
app_service.is_interested_in_user = Mock(return_value=True)
|
app_service.is_interested_in_user = Mock(return_value=True)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
# This just needs to return a truth-y value.
|
|
||||||
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
|
class FakeUserInfo:
|
||||||
|
is_guest = False
|
||||||
|
|
||||||
|
self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
|
||||||
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
|
@ -341,7 +344,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_guest_user_from_macaroon(self) -> None:
|
def test_get_guest_user_from_macaroon(self) -> None:
|
||||||
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
|
class FakeUserInfo:
|
||||||
|
is_guest = True
|
||||||
|
|
||||||
|
self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
|
||||||
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||||
|
|
||||||
user_id = "@baldrick:matrix.org"
|
user_id = "@baldrick:matrix.org"
|
||||||
|
|
|
@ -16,7 +16,7 @@ from twisted.test.proto_helpers import MemoryReactor
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
from synapse.api.errors import ThreepidValidationError
|
from synapse.api.errors import ThreepidValidationError
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, UserID, UserInfo
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase, override_config
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
@ -35,24 +35,22 @@ class RegistrationStoreTestCase(HomeserverTestCase):
|
||||||
self.get_success(self.store.register_user(self.user_id, self.pwhash))
|
self.get_success(self.store.register_user(self.user_id, self.pwhash))
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
{
|
UserInfo(
|
||||||
# TODO(paul): Surely this field should be 'user_id', not 'name'
|
# TODO(paul): Surely this field should be 'user_id', not 'name'
|
||||||
"name": self.user_id,
|
user_id=UserID.from_string(self.user_id),
|
||||||
"password_hash": self.pwhash,
|
is_admin=False,
|
||||||
"admin": 0,
|
is_guest=False,
|
||||||
"is_guest": 0,
|
consent_server_notice_sent=None,
|
||||||
"consent_version": None,
|
consent_ts=None,
|
||||||
"consent_ts": None,
|
consent_version=None,
|
||||||
"consent_server_notice_sent": None,
|
appservice_id=None,
|
||||||
"appservice_id": None,
|
creation_ts=0,
|
||||||
"creation_ts": 0,
|
user_type=None,
|
||||||
"user_type": None,
|
is_deactivated=False,
|
||||||
"deactivated": 0,
|
locked=False,
|
||||||
"locked": 0,
|
is_shadow_banned=False,
|
||||||
"shadow_banned": 0,
|
approved=True,
|
||||||
"approved": 1,
|
),
|
||||||
"last_seen_ts": None,
|
|
||||||
},
|
|
||||||
(self.get_success(self.store.get_user_by_id(self.user_id))),
|
(self.get_success(self.store.get_user_by_id(self.user_id))),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -65,9 +63,11 @@ class RegistrationStoreTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
||||||
assert user
|
assert user
|
||||||
self.assertEqual(user["consent_version"], "1")
|
self.assertEqual(user.consent_version, "1")
|
||||||
self.assertGreater(user["consent_ts"], before_consent)
|
self.assertIsNotNone(user.consent_ts)
|
||||||
self.assertLess(user["consent_ts"], self.clock.time_msec())
|
assert user.consent_ts is not None
|
||||||
|
self.assertGreater(user.consent_ts, before_consent)
|
||||||
|
self.assertLess(user.consent_ts, self.clock.time_msec())
|
||||||
|
|
||||||
def test_add_tokens(self) -> None:
|
def test_add_tokens(self) -> None:
|
||||||
self.get_success(self.store.register_user(self.user_id, self.pwhash))
|
self.get_success(self.store.register_user(self.user_id, self.pwhash))
|
||||||
|
@ -215,7 +215,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
||||||
assert user is not None
|
assert user is not None
|
||||||
self.assertTrue(user["approved"])
|
self.assertTrue(user.approved)
|
||||||
|
|
||||||
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
||||||
self.assertTrue(approved)
|
self.assertTrue(approved)
|
||||||
|
@ -228,7 +228,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
||||||
assert user is not None
|
assert user is not None
|
||||||
self.assertFalse(user["approved"])
|
self.assertFalse(user.approved)
|
||||||
|
|
||||||
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
||||||
self.assertFalse(approved)
|
self.assertFalse(approved)
|
||||||
|
@ -248,7 +248,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
|
||||||
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
||||||
self.assertIsNotNone(user)
|
self.assertIsNotNone(user)
|
||||||
assert user is not None
|
assert user is not None
|
||||||
self.assertEqual(user["approved"], 1)
|
self.assertEqual(user.approved, 1)
|
||||||
|
|
||||||
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
||||||
self.assertTrue(approved)
|
self.assertTrue(approved)
|
||||||
|
|
Loading…
Reference in a new issue