Convert _base, profile, and _receipts handlers to async/await (#7860)

This commit is contained in:
Patrick Cloke 2020-07-17 07:08:30 -04:00 committed by GitHub
parent fff483ea96
commit 6fca1b3506
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 53 additions and 59 deletions

1
changelog.d/7860.misc Normal file
View file

@ -0,0 +1 @@
Convert _base, profile, and _receipts handlers to async/await.

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
import synapse.state import synapse.state
import synapse.storage import synapse.storage
import synapse.types import synapse.types
@ -66,8 +64,7 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks async def ratelimit(self, requester, update=True, is_admin_redaction=False):
def ratelimit(self, requester, update=True, is_admin_redaction=False):
"""Ratelimits requests. """Ratelimits requests.
Args: Args:
@ -99,7 +96,7 @@ class BaseHandler(object):
burst_count = self._rc_message.burst_count burst_count = self._rc_message.burst_count
# Check if there is a per user override in the DB. # Check if there is a per user override in the DB.
override = yield self.store.get_ratelimit_for_user(user_id) override = await self.store.get_ratelimit_for_user(user_id)
if override: if override:
# If overridden with a null Hz then ratelimiting has been entirely # If overridden with a null Hz then ratelimiting has been entirely
# disabled for the user # disabled for the user

View file

@ -488,11 +488,15 @@ class EventCreationHandler(object):
try: try:
if "displayname" not in content: if "displayname" not in content:
displayname = yield profile.get_displayname(target) displayname = yield defer.ensureDeferred(
profile.get_displayname(target)
)
if displayname is not None: if displayname is not None:
content["displayname"] = displayname content["displayname"] = displayname
if "avatar_url" not in content: if "avatar_url" not in content:
avatar_url = yield profile.get_avatar_url(target) avatar_url = yield defer.ensureDeferred(
profile.get_avatar_url(target)
)
if avatar_url is not None: if avatar_url is not None:
content["avatar_url"] = avatar_url content["avatar_url"] = avatar_url
except Exception as e: except Exception as e:

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -54,16 +52,15 @@ class BaseProfileHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler() self.user_directory_handler = hs.get_user_directory_handler()
@defer.inlineCallbacks async def get_profile(self, user_id):
def get_profile(self, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
try: try:
displayname = yield self.store.get_profile_displayname( displayname = await self.store.get_profile_displayname(
target_user.localpart target_user.localpart
) )
avatar_url = yield self.store.get_profile_avatar_url( avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart target_user.localpart
) )
except StoreError as e: except StoreError as e:
@ -74,7 +71,7 @@ class BaseProfileHandler(BaseHandler):
return {"displayname": displayname, "avatar_url": avatar_url} return {"displayname": displayname, "avatar_url": avatar_url}
else: else:
try: try:
result = yield self.federation.make_query( result = await self.federation.make_query(
destination=target_user.domain, destination=target_user.domain,
query_type="profile", query_type="profile",
args={"user_id": user_id}, args={"user_id": user_id},
@ -86,8 +83,7 @@ class BaseProfileHandler(BaseHandler):
except HttpResponseException as e: except HttpResponseException as e:
raise e.to_synapse_error() raise e.to_synapse_error()
@defer.inlineCallbacks async def get_profile_from_cache(self, user_id):
def get_profile_from_cache(self, user_id):
"""Get the profile information from our local cache. If the user is """Get the profile information from our local cache. If the user is
ours then the profile information will always be corect. Otherwise, ours then the profile information will always be corect. Otherwise,
it may be out of date/missing. it may be out of date/missing.
@ -95,10 +91,10 @@ class BaseProfileHandler(BaseHandler):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
try: try:
displayname = yield self.store.get_profile_displayname( displayname = await self.store.get_profile_displayname(
target_user.localpart target_user.localpart
) )
avatar_url = yield self.store.get_profile_avatar_url( avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart target_user.localpart
) )
except StoreError as e: except StoreError as e:
@ -108,14 +104,13 @@ class BaseProfileHandler(BaseHandler):
return {"displayname": displayname, "avatar_url": avatar_url} return {"displayname": displayname, "avatar_url": avatar_url}
else: else:
profile = yield self.store.get_from_remote_profile_cache(user_id) profile = await self.store.get_from_remote_profile_cache(user_id)
return profile or {} return profile or {}
@defer.inlineCallbacks async def get_displayname(self, target_user):
def get_displayname(self, target_user):
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
try: try:
displayname = yield self.store.get_profile_displayname( displayname = await self.store.get_profile_displayname(
target_user.localpart target_user.localpart
) )
except StoreError as e: except StoreError as e:
@ -126,7 +121,7 @@ class BaseProfileHandler(BaseHandler):
return displayname return displayname
else: else:
try: try:
result = yield self.federation.make_query( result = await self.federation.make_query(
destination=target_user.domain, destination=target_user.domain,
query_type="profile", query_type="profile",
args={"user_id": target_user.to_string(), "field": "displayname"}, args={"user_id": target_user.to_string(), "field": "displayname"},
@ -189,11 +184,10 @@ class BaseProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user) await self._update_join_states(requester, target_user)
@defer.inlineCallbacks async def get_avatar_url(self, target_user):
def get_avatar_url(self, target_user):
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
try: try:
avatar_url = yield self.store.get_profile_avatar_url( avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart target_user.localpart
) )
except StoreError as e: except StoreError as e:
@ -203,7 +197,7 @@ class BaseProfileHandler(BaseHandler):
return avatar_url return avatar_url
else: else:
try: try:
result = yield self.federation.make_query( result = await self.federation.make_query(
destination=target_user.domain, destination=target_user.domain,
query_type="profile", query_type="profile",
args={"user_id": target_user.to_string(), "field": "avatar_url"}, args={"user_id": target_user.to_string(), "field": "avatar_url"},
@ -253,8 +247,7 @@ class BaseProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user) await self._update_join_states(requester, target_user)
@defer.inlineCallbacks async def on_profile_query(self, args):
def on_profile_query(self, args):
user = UserID.from_string(args["user_id"]) user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver") raise SynapseError(400, "User is not hosted on this homeserver")
@ -264,12 +257,12 @@ class BaseProfileHandler(BaseHandler):
response = {} response = {}
try: try:
if just_field is None or just_field == "displayname": if just_field is None or just_field == "displayname":
response["displayname"] = yield self.store.get_profile_displayname( response["displayname"] = await self.store.get_profile_displayname(
user.localpart user.localpart
) )
if just_field is None or just_field == "avatar_url": if just_field is None or just_field == "avatar_url":
response["avatar_url"] = yield self.store.get_profile_avatar_url( response["avatar_url"] = await self.store.get_profile_avatar_url(
user.localpart user.localpart
) )
except StoreError as e: except StoreError as e:
@ -304,8 +297,7 @@ class BaseProfileHandler(BaseHandler):
"Failed to update join event for room %s - %s", room_id, str(e) "Failed to update join event for room %s - %s", room_id, str(e)
) )
@defer.inlineCallbacks async def check_profile_query_allowed(self, target_user, requester=None):
def check_profile_query_allowed(self, target_user, requester=None):
"""Checks whether a profile query is allowed. If the """Checks whether a profile query is allowed. If the
'require_auth_for_profile_requests' config flag is set to True and a 'require_auth_for_profile_requests' config flag is set to True and a
'requester' is provided, the query is only allowed if the two users 'requester' is provided, the query is only allowed if the two users
@ -337,8 +329,8 @@ class BaseProfileHandler(BaseHandler):
return return
try: try:
requester_rooms = yield self.store.get_rooms_for_user(requester.to_string()) requester_rooms = await self.store.get_rooms_for_user(requester.to_string())
target_user_rooms = yield self.store.get_rooms_for_user( target_user_rooms = await self.store.get_rooms_for_user(
target_user.to_string() target_user.to_string()
) )
@ -371,25 +363,24 @@ class MasterProfileHandler(BaseProfileHandler):
"Update remote profile", self._update_remote_profile_cache "Update remote profile", self._update_remote_profile_cache
) )
@defer.inlineCallbacks async def _update_remote_profile_cache(self):
def _update_remote_profile_cache(self):
"""Called periodically to check profiles of remote users we haven't """Called periodically to check profiles of remote users we haven't
checked in a while. checked in a while.
""" """
entries = yield self.store.get_remote_profile_cache_entries_that_expire( entries = await self.store.get_remote_profile_cache_entries_that_expire(
last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
) )
for user_id, displayname, avatar_url in entries: for user_id, displayname, avatar_url in entries:
is_subscribed = yield self.store.is_subscribed_remote_profile_for_user( is_subscribed = await self.store.is_subscribed_remote_profile_for_user(
user_id user_id
) )
if not is_subscribed: if not is_subscribed:
yield self.store.maybe_delete_remote_profile_cache(user_id) await self.store.maybe_delete_remote_profile_cache(user_id)
continue continue
try: try:
profile = yield self.federation.make_query( profile = await self.federation.make_query(
destination=get_domain_from_id(user_id), destination=get_domain_from_id(user_id),
query_type="profile", query_type="profile",
args={"user_id": user_id}, args={"user_id": user_id},
@ -398,7 +389,7 @@ class MasterProfileHandler(BaseProfileHandler):
except Exception: except Exception:
logger.exception("Failed to get avatar_url") logger.exception("Failed to get avatar_url")
yield self.store.update_remote_profile_cache( await self.store.update_remote_profile_cache(
user_id, displayname, avatar_url user_id, displayname, avatar_url
) )
continue continue
@ -407,4 +398,4 @@ class MasterProfileHandler(BaseProfileHandler):
new_avatar = profile.get("avatar_url") new_avatar = profile.get("avatar_url")
# We always hit update to update the last_check timestamp # We always hit update to update the last_check timestamp
yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar) await self.store.update_remote_profile_cache(user_id, new_name, new_avatar)

View file

@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
import logging import logging
from twisted.internet import defer
from synapse.handlers._base import BaseHandler from synapse.handlers._base import BaseHandler
from synapse.types import ReadReceipt, get_domain_from_id from synapse.types import ReadReceipt, get_domain_from_id
from synapse.util.async_helpers import maybe_awaitable from synapse.util.async_helpers import maybe_awaitable
@ -129,15 +127,14 @@ class ReceiptEventSource(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks async def get_new_events(self, from_key, room_ids, **kwargs):
def get_new_events(self, from_key, room_ids, **kwargs):
from_key = int(from_key) from_key = int(from_key)
to_key = yield self.get_current_key() to_key = self.get_current_key()
if from_key == to_key: if from_key == to_key:
return [], to_key return [], to_key
events = yield self.store.get_linearized_receipts_for_rooms( events = await self.store.get_linearized_receipts_for_rooms(
room_ids, from_key=from_key, to_key=to_key room_ids, from_key=from_key, to_key=to_key
) )
@ -146,8 +143,7 @@ class ReceiptEventSource(object):
def get_current_key(self, direction="f"): def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id() return self.store.get_max_receipt_stream_id()
@defer.inlineCallbacks async def get_pagination_rows(self, user, config, key):
def get_pagination_rows(self, user, config, key):
to_key = int(config.from_key) to_key = int(config.from_key)
if config.to_key: if config.to_key:
@ -155,8 +151,8 @@ class ReceiptEventSource(object):
else: else:
from_key = None from_key = None
room_ids = yield self.store.get_rooms_for_user(user.to_string()) room_ids = await self.store.get_rooms_for_user(user.to_string())
events = yield self.store.get_linearized_receipts_for_rooms( events = await self.store.get_linearized_receipts_for_rooms(
room_ids, from_key=from_key, to_key=to_key room_ids, from_key=from_key, to_key=to_key
) )

View file

@ -72,7 +72,9 @@ class ProfileTestCase(unittest.TestCase):
def test_get_my_name(self): def test_get_my_name(self):
yield self.store.set_profile_displayname(self.frank.localpart, "Frank") yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
displayname = yield self.handler.get_displayname(self.frank) displayname = yield defer.ensureDeferred(
self.handler.get_displayname(self.frank)
)
self.assertEquals("Frank", displayname) self.assertEquals("Frank", displayname)
@ -140,7 +142,9 @@ class ProfileTestCase(unittest.TestCase):
{"displayname": "Alice"} {"displayname": "Alice"}
) )
displayname = yield self.handler.get_displayname(self.alice) displayname = yield defer.ensureDeferred(
self.handler.get_displayname(self.alice)
)
self.assertEquals(displayname, "Alice") self.assertEquals(displayname, "Alice")
self.mock_federation.make_query.assert_called_with( self.mock_federation.make_query.assert_called_with(
@ -155,8 +159,10 @@ class ProfileTestCase(unittest.TestCase):
yield self.store.create_profile("caroline") yield self.store.create_profile("caroline")
yield self.store.set_profile_displayname("caroline", "Caroline") yield self.store.set_profile_displayname("caroline", "Caroline")
response = yield self.query_handlers["profile"]( response = yield defer.ensureDeferred(
{"user_id": "@caroline:test", "field": "displayname"} self.query_handlers["profile"](
{"user_id": "@caroline:test", "field": "displayname"}
)
) )
self.assertEquals({"displayname": "Caroline"}, response) self.assertEquals({"displayname": "Caroline"}, response)
@ -166,8 +172,7 @@ class ProfileTestCase(unittest.TestCase):
yield self.store.set_profile_avatar_url( yield self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png" self.frank.localpart, "http://my.server/me.png"
) )
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
avatar_url = yield self.handler.get_avatar_url(self.frank)
self.assertEquals("http://my.server/me.png", avatar_url) self.assertEquals("http://my.server/me.png", avatar_url)