Add type hints to the push mailer module. (#8882)

This commit is contained in:
Patrick Cloke 2020-12-07 07:10:22 -05:00 committed by GitHub
parent 96358cb424
commit 02e588856a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 85 additions and 40 deletions

1
changelog.d/8882.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to push module.

View file

@ -55,6 +55,7 @@ files =
synapse/metrics, synapse/metrics,
synapse/module_api, synapse/module_api,
synapse/notifier.py, synapse/notifier.py,
synapse/push/mailer.py,
synapse/push/pusherpool.py, synapse/push/pusherpool.py,
synapse/push/push_rule_evaluator.py, synapse/push/push_rule_evaluator.py,
synapse/replication, synapse/replication,

View file

@ -19,7 +19,7 @@ import logging
import urllib.parse import urllib.parse
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText from email.mime.text import MIMEText
from typing import Iterable, List, TypeVar from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
import bleach import bleach
import jinja2 import jinja2
@ -27,16 +27,20 @@ import jinja2
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.config.emailconfig import EmailSubjectConfig from synapse.config.emailconfig import EmailSubjectConfig
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.push.presentable_names import ( from synapse.push.presentable_names import (
calculate_room_name, calculate_room_name,
descriptor_from_member_events, descriptor_from_member_events,
name_from_member_event, name_from_member_event,
) )
from synapse.types import UserID from synapse.types import StateMap, UserID
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
T = TypeVar("T") T = TypeVar("T")
@ -93,7 +97,13 @@ ALLOWED_ATTRS = {
class Mailer: class Mailer:
def __init__(self, hs, app_name, template_html, template_text): def __init__(
self,
hs: "HomeServer",
app_name: str,
template_html: jinja2.Template,
template_text: jinja2.Template,
):
self.hs = hs self.hs = hs
self.template_html = template_html self.template_html = template_html
self.template_text = template_text self.template_text = template_text
@ -108,17 +118,19 @@ class Mailer:
logger.info("Created Mailer for app_name %s" % app_name) logger.info("Created Mailer for app_name %s" % app_name)
async def send_password_reset_mail(self, email_address, token, client_secret, sid): async def send_password_reset_mail(
self, email_address: str, token: str, client_secret: str, sid: str
) -> None:
"""Send an email with a password reset link to a user """Send an email with a password reset link to a user
Args: Args:
email_address (str): Email address we're sending the password email_address: Email address we're sending the password
reset to reset to
token (str): Unique token generated by the server to verify token: Unique token generated by the server to verify
the email was received the email was received
client_secret (str): Unique token generated by the client to client_secret: Unique token generated by the client to
group together multiple email sending attempts group together multiple email sending attempts
sid (str): The generated session ID sid: The generated session ID
""" """
params = {"token": token, "client_secret": client_secret, "sid": sid} params = {"token": token, "client_secret": client_secret, "sid": sid}
link = ( link = (
@ -136,17 +148,19 @@ class Mailer:
template_vars, template_vars,
) )
async def send_registration_mail(self, email_address, token, client_secret, sid): async def send_registration_mail(
self, email_address: str, token: str, client_secret: str, sid: str
) -> None:
"""Send an email with a registration confirmation link to a user """Send an email with a registration confirmation link to a user
Args: Args:
email_address (str): Email address we're sending the registration email_address: Email address we're sending the registration
link to link to
token (str): Unique token generated by the server to verify token: Unique token generated by the server to verify
the email was received the email was received
client_secret (str): Unique token generated by the client to client_secret: Unique token generated by the client to
group together multiple email sending attempts group together multiple email sending attempts
sid (str): The generated session ID sid: The generated session ID
""" """
params = {"token": token, "client_secret": client_secret, "sid": sid} params = {"token": token, "client_secret": client_secret, "sid": sid}
link = ( link = (
@ -164,18 +178,20 @@ class Mailer:
template_vars, template_vars,
) )
async def send_add_threepid_mail(self, email_address, token, client_secret, sid): async def send_add_threepid_mail(
self, email_address: str, token: str, client_secret: str, sid: str
) -> None:
"""Send an email with a validation link to a user for adding a 3pid to their account """Send an email with a validation link to a user for adding a 3pid to their account
Args: Args:
email_address (str): Email address we're sending the validation link to email_address: Email address we're sending the validation link to
token (str): Unique token generated by the server to verify the email was received token: Unique token generated by the server to verify the email was received
client_secret (str): Unique token generated by the client to group together client_secret: Unique token generated by the client to group together
multiple email sending attempts multiple email sending attempts
sid (str): The generated session ID sid: The generated session ID
""" """
params = {"token": token, "client_secret": client_secret, "sid": sid} params = {"token": token, "client_secret": client_secret, "sid": sid}
link = ( link = (
@ -194,8 +210,13 @@ class Mailer:
) )
async def send_notification_mail( async def send_notification_mail(
self, app_id, user_id, email_address, push_actions, reason self,
): app_id: str,
user_id: str,
email_address: str,
push_actions: Iterable[Dict[str, Any]],
reason: Dict[str, Any],
) -> None:
"""Send email regarding a user's room notifications""" """Send email regarding a user's room notifications"""
rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions]) rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
@ -203,7 +224,7 @@ class Mailer:
[pa["event_id"] for pa in push_actions] [pa["event_id"] for pa in push_actions]
) )
notifs_by_room = {} notifs_by_room = {} # type: Dict[str, List[Dict[str, Any]]]
for pa in push_actions: for pa in push_actions:
notifs_by_room.setdefault(pa["room_id"], []).append(pa) notifs_by_room.setdefault(pa["room_id"], []).append(pa)
@ -262,7 +283,9 @@ class Mailer:
await self.send_email(email_address, summary_text, template_vars) await self.send_email(email_address, summary_text, template_vars)
async def send_email(self, email_address, subject, extra_template_vars): async def send_email(
self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
) -> None:
"""Send an email with the given information and template text""" """Send an email with the given information and template text"""
try: try:
from_string = self.hs.config.email_notif_from % {"app": self.app_name} from_string = self.hs.config.email_notif_from % {"app": self.app_name}
@ -315,8 +338,13 @@ class Mailer:
) )
async def get_room_vars( async def get_room_vars(
self, room_id, user_id, notifs, notif_events, room_state_ids self,
): room_id: str,
user_id: str,
notifs: Iterable[Dict[str, Any]],
notif_events: Dict[str, EventBase],
room_state_ids: StateMap[str],
) -> Dict[str, Any]:
# Check if one of the notifs is an invite event for the user. # Check if one of the notifs is an invite event for the user.
is_invite = False is_invite = False
for n in notifs: for n in notifs:
@ -334,7 +362,7 @@ class Mailer:
"notifs": [], "notifs": [],
"invite": is_invite, "invite": is_invite,
"link": self.make_room_link(room_id), "link": self.make_room_link(room_id),
} } # type: Dict[str, Any]
if not is_invite: if not is_invite:
for n in notifs: for n in notifs:
@ -365,7 +393,13 @@ class Mailer:
return room_vars return room_vars
async def get_notif_vars(self, notif, user_id, notif_event, room_state_ids): async def get_notif_vars(
self,
notif: Dict[str, Any],
user_id: str,
notif_event: EventBase,
room_state_ids: StateMap[str],
) -> Dict[str, Any]:
results = await self.store.get_events_around( results = await self.store.get_events_around(
notif["room_id"], notif["room_id"],
notif["event_id"], notif["event_id"],
@ -391,7 +425,9 @@ class Mailer:
return ret return ret
async def get_message_vars(self, notif, event, room_state_ids): async def get_message_vars(
self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str]
) -> Optional[Dict[str, Any]]:
if event.type != EventTypes.Message and event.type != EventTypes.Encrypted: if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
return None return None
@ -432,7 +468,9 @@ class Mailer:
return ret return ret
def add_text_message_vars(self, messagevars, event): def add_text_message_vars(
self, messagevars: Dict[str, Any], event: EventBase
) -> None:
msgformat = event.content.get("format") msgformat = event.content.get("format")
messagevars["format"] = msgformat messagevars["format"] = msgformat
@ -445,15 +483,18 @@ class Mailer:
elif body: elif body:
messagevars["body_text_html"] = safe_text(body) messagevars["body_text_html"] = safe_text(body)
return messagevars def add_image_message_vars(
self, messagevars: Dict[str, Any], event: EventBase
def add_image_message_vars(self, messagevars, event): ) -> None:
messagevars["image_url"] = event.content["url"] messagevars["image_url"] = event.content["url"]
return messagevars
async def make_summary_text( async def make_summary_text(
self, notifs_by_room, room_state_ids, notif_events, user_id, reason self,
notifs_by_room: Dict[str, List[Dict[str, Any]]],
room_state_ids: Dict[str, StateMap[str]],
notif_events: Dict[str, EventBase],
user_id: str,
reason: Dict[str, Any],
): ):
if len(notifs_by_room) == 1: if len(notifs_by_room) == 1:
# Only one room has new stuff # Only one room has new stuff
@ -580,7 +621,7 @@ class Mailer:
"app": self.app_name, "app": self.app_name,
} }
def make_room_link(self, room_id): def make_room_link(self, room_id: str) -> str:
if self.hs.config.email_riot_base_url: if self.hs.config.email_riot_base_url:
base_url = "%s/#/room" % (self.hs.config.email_riot_base_url) base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
elif self.app_name == "Vector": elif self.app_name == "Vector":
@ -590,7 +631,7 @@ class Mailer:
base_url = "https://matrix.to/#" base_url = "https://matrix.to/#"
return "%s/%s" % (base_url, room_id) return "%s/%s" % (base_url, room_id)
def make_notif_link(self, notif): def make_notif_link(self, notif: Dict[str, str]) -> str:
if self.hs.config.email_riot_base_url: if self.hs.config.email_riot_base_url:
return "%s/#/room/%s/%s" % ( return "%s/#/room/%s/%s" % (
self.hs.config.email_riot_base_url, self.hs.config.email_riot_base_url,
@ -606,7 +647,9 @@ class Mailer:
else: else:
return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"]) return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
def make_unsubscribe_link(self, user_id, app_id, email_address): def make_unsubscribe_link(
self, user_id: str, app_id: str, email_address: str
) -> str:
params = { params = {
"access_token": self.macaroon_gen.generate_delete_pusher_token(user_id), "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
"app_id": app_id, "app_id": app_id,
@ -620,7 +663,7 @@ class Mailer:
) )
def safe_markup(raw_html): def safe_markup(raw_html: str) -> jinja2.Markup:
return jinja2.Markup( return jinja2.Markup(
bleach.linkify( bleach.linkify(
bleach.clean( bleach.clean(
@ -635,7 +678,7 @@ def safe_markup(raw_html):
) )
def safe_text(raw_text): def safe_text(raw_text: str) -> jinja2.Markup:
""" """
Process text: treat it as HTML but escape any tags (ie. just escape the Process text: treat it as HTML but escape any tags (ie. just escape the
HTML) then linkify it. HTML) then linkify it.
@ -655,7 +698,7 @@ def deduped_ordered_list(it: Iterable[T]) -> List[T]:
return ret return ret
def string_ordinal_total(s): def string_ordinal_total(s: str) -> int:
tot = 0 tot = 0
for c in s: for c in s:
tot += ord(c) tot += ord(c)