forward requester id to check username for spam callbacks (#17916)

This commit is contained in:
Wilson 2024-12-13 09:17:41 -05:00 committed by GitHub
parent 483602efb2
commit eedab12e6d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 70 additions and 7 deletions

View file

@ -0,0 +1 @@
Module developers will have access to user id of requester when adding `check_username_for_spam` callbacks to `spam_checker_module_callbacks`. Contributed by Wilson@Pangea.chat.

View file

@ -245,7 +245,7 @@ this callback.
_First introduced in Synapse v1.37.0_ _First introduced in Synapse v1.37.0_
```python ```python
async def check_username_for_spam(user_profile: synapse.module_api.UserProfile) -> bool async def check_username_for_spam(user_profile: synapse.module_api.UserProfile, requester_id: str) -> bool
``` ```
Called when computing search results in the user directory. The module must return a Called when computing search results in the user directory. The module must return a
@ -264,6 +264,8 @@ The profile is represented as a dictionary with the following keys:
The module is given a copy of the original dictionary, so modifying it from within the The module is given a copy of the original dictionary, so modifying it from within the
module cannot modify a user's profile when included in user directory search results. module cannot modify a user's profile when included in user directory search results.
The requester_id parameter is the ID of the user that called the user directory API.
If multiple modules implement this callback, they will be considered in order. If a If multiple modules implement this callback, they will be considered in order. If a
callback returns `False`, Synapse falls through to the next one. The value of the first callback returns `False`, Synapse falls through to the next one. The value of the first
callback that does not return `False` will be used. If this happens, Synapse will not call callback that does not return `False` will be used. If this happens, Synapse will not call

View file

@ -72,8 +72,8 @@ class ExampleSpamChecker:
async def user_may_publish_room(self, userid, room_id): async def user_may_publish_room(self, userid, room_id):
return True # allow publishing of all rooms return True # allow publishing of all rooms
async def check_username_for_spam(self, user_profile): async def check_username_for_spam(self, user_profile, requester_id):
return False # allow all usernames return False # allow all usernames regardless of requester
async def check_registration_for_spam( async def check_registration_for_spam(
self, self,

View file

@ -161,7 +161,7 @@ class UserDirectoryHandler(StateDeltasHandler):
non_spammy_users = [] non_spammy_users = []
for user in results["results"]: for user in results["results"]:
if not await self._spam_checker_module_callbacks.check_username_for_spam( if not await self._spam_checker_module_callbacks.check_username_for_spam(
user user, user_id
): ):
non_spammy_users.append(user) non_spammy_users.append(user)
results["results"] = non_spammy_users results["results"] = non_spammy_users

View file

@ -31,6 +31,7 @@ from typing import (
Optional, Optional,
Tuple, Tuple,
Union, Union,
cast,
) )
# `Literal` appears with Python 3.8. # `Literal` appears with Python 3.8.
@ -168,7 +169,10 @@ USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[
] ]
], ],
] ]
CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]] CHECK_USERNAME_FOR_SPAM_CALLBACK = Union[
Callable[[UserProfile], Awaitable[bool]],
Callable[[UserProfile, str], Awaitable[bool]],
]
LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[ LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
[ [
Optional[dict], Optional[dict],
@ -716,7 +720,9 @@ class SpamCheckerModuleApiCallbacks:
return self.NOT_SPAM return self.NOT_SPAM
async def check_username_for_spam(self, user_profile: UserProfile) -> bool: async def check_username_for_spam(
self, user_profile: UserProfile, requester_id: str
) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server. """Checks if a user ID or display name are considered "spammy" by this server.
If the server considers a username spammy, then it will not be included in If the server considers a username spammy, then it will not be included in
@ -727,15 +733,33 @@ class SpamCheckerModuleApiCallbacks:
* user_id * user_id
* display_name * display_name
* avatar_url * avatar_url
requester_id: The user ID of the user making the user directory search request.
Returns: Returns:
True if the user is spammy. True if the user is spammy.
""" """
for callback in self._check_username_for_spam_callbacks: for callback in self._check_username_for_spam_callbacks:
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
checker_args = inspect.signature(callback)
# Make a copy of the user profile object to ensure the spam checker cannot # Make a copy of the user profile object to ensure the spam checker cannot
# modify it. # modify it.
res = await delay_cancellation(callback(user_profile.copy())) # Also ensure backwards compatibility with spam checker callbacks
# that don't expect the requester_id argument.
if len(checker_args.parameters) == 2:
callback_with_requester_id = cast(
Callable[[UserProfile, str], Awaitable[bool]], callback
)
res = await delay_cancellation(
callback_with_requester_id(user_profile.copy(), requester_id)
)
else:
callback_without_requester_id = cast(
Callable[[UserProfile], Awaitable[bool]], callback
)
res = await delay_cancellation(
callback_without_requester_id(user_profile.copy())
)
if res: if res:
return True return True

View file

@ -796,6 +796,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10)) s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1) self.assertEqual(len(s["results"]), 1)
# Kept old spam checker without `requester_id` tests for backwards compatibility.
async def allow_all(user_profile: UserProfile) -> bool: async def allow_all(user_profile: UserProfile) -> bool:
# Allow all users. # Allow all users.
return False return False
@ -809,6 +810,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10)) s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1) self.assertEqual(len(s["results"]), 1)
# Kept old spam checker without `requester_id` tests for backwards compatibility.
# Configure a spam checker that filters all users. # Configure a spam checker that filters all users.
async def block_all(user_profile: UserProfile) -> bool: async def block_all(user_profile: UserProfile) -> bool:
# All users are spammy. # All users are spammy.
@ -820,6 +822,40 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10)) s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 0) self.assertEqual(len(s["results"]), 0)
async def allow_all_expects_requester_id(
user_profile: UserProfile, requester_id: str
) -> bool:
self.assertEqual(requester_id, u1)
# Allow all users.
return False
# Configure a spam checker that does not filter any users.
spam_checker = self.hs.get_module_api_callbacks().spam_checker
spam_checker._check_username_for_spam_callbacks = [
allow_all_expects_requester_id
]
# The results do not change:
# We get one search result when searching for user2 by user1.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
async def block_all_expects_requester_id(
user_profile: UserProfile, requester_id: str
) -> bool:
self.assertEqual(requester_id, u1)
# All users are spammy.
return True
spam_checker._check_username_for_spam_callbacks = [
block_all_expects_requester_id
]
# User1 now gets no search results for any of the other users.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 0)
@override_config( @override_config(
{ {
"spam_checker": { "spam_checker": {