Add types to StreamToken and RoomStreamToken (#8279)

The intention here is to change `StreamToken.room_key` to be a `RoomStreamToken` in a future PR, but that is a big enough change without this refactoring too.
This commit is contained in:
Erik Johnston 2020-09-08 16:48:15 +01:00 committed by GitHub
parent 094896a69d
commit 63c0e9e195
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 95 additions and 91 deletions

1
changelog.d/8279.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `StreamToken` and `RoomStreamToken` classes.

View file

@ -1310,12 +1310,11 @@ class SyncHandler:
presence_source = self.event_sources.sources["presence"] presence_source = self.event_sources.sources["presence"]
since_token = sync_result_builder.since_token since_token = sync_result_builder.since_token
presence_key = None
include_offline = False
if since_token and not sync_result_builder.full_state: if since_token and not sync_result_builder.full_state:
presence_key = since_token.presence_key presence_key = since_token.presence_key
include_offline = True include_offline = True
else:
presence_key = None
include_offline = False
presence, presence_key = await presence_source.get_new_events( presence, presence_key = await presence_source.get_new_events(
user=user, user=user,

View file

@ -481,7 +481,7 @@ class DeviceWorkerStore(SQLBaseStore):
} }
async def get_users_whose_devices_changed( async def get_users_whose_devices_changed(
self, from_key: str, user_ids: Iterable[str] self, from_key: int, user_ids: Iterable[str]
) -> Set[str]: ) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that """Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids. are in the given list of user_ids.
@ -493,7 +493,6 @@ class DeviceWorkerStore(SQLBaseStore):
Returns: Returns:
The set of user_ids whose devices have changed since `from_key` The set of user_ids whose devices have changed since `from_key`
""" """
from_key = int(from_key)
# Get set of users who *may* have changed. Users not in the returned # Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed. # list have definitely not changed.
@ -527,7 +526,7 @@ class DeviceWorkerStore(SQLBaseStore):
) )
async def get_users_whose_signatures_changed( async def get_users_whose_signatures_changed(
self, user_id: str, from_key: str self, user_id: str, from_key: int
) -> Set[str]: ) -> Set[str]:
"""Get the users who have new cross-signing signatures made by `user_id` since """Get the users who have new cross-signing signatures made by `user_id` since
`from_key`. `from_key`.
@ -539,7 +538,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns: Returns:
A set of user IDs with updated signatures. A set of user IDs with updated signatures.
""" """
from_key = int(from_key)
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key): if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
sql = """ sql = """
SELECT DISTINCT user_ids FROM user_signature_stream SELECT DISTINCT user_ids FROM user_signature_stream

View file

@ -79,8 +79,8 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause( def generate_pagination_where_clause(
direction: str, direction: str,
column_names: Tuple[str, str], column_names: Tuple[str, str],
from_token: Optional[Tuple[int, int]], from_token: Optional[Tuple[Optional[int], int]],
to_token: Optional[Tuple[int, int]], to_token: Optional[Tuple[Optional[int], int]],
engine: BaseDatabaseEngine, engine: BaseDatabaseEngine,
) -> str: ) -> str:
"""Creates an SQL expression to bound the columns by the pagination """Creates an SQL expression to bound the columns by the pagination
@ -535,13 +535,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if limit == 0: if limit == 0:
return [], end_token return [], end_token
end_token = RoomStreamToken.parse(end_token) parsed_end_token = RoomStreamToken.parse(end_token)
rows, token = await self.db_pool.runInteraction( rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room", "get_recent_event_ids_for_room",
self._paginate_room_events_txn, self._paginate_room_events_txn,
room_id, room_id,
from_token=end_token, from_token=parsed_end_token,
limit=limit, limit=limit,
) )
@ -989,8 +989,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
bounds = generate_pagination_where_clause( bounds = generate_pagination_where_clause(
direction=direction, direction=direction,
column_names=("topological_ordering", "stream_ordering"), column_names=("topological_ordering", "stream_ordering"),
from_token=from_token, from_token=from_token.as_tuple(),
to_token=to_token, to_token=to_token.as_tuple() if to_token else None,
engine=self.database_engine, engine=self.database_engine,
) )
@ -1083,16 +1083,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
and `to_key`). and `to_key`).
""" """
from_key = RoomStreamToken.parse(from_key) parsed_from_key = RoomStreamToken.parse(from_key)
parsed_to_key = None
if to_key: if to_key:
to_key = RoomStreamToken.parse(to_key) parsed_to_key = RoomStreamToken.parse(to_key)
rows, token = await self.db_pool.runInteraction( rows, token = await self.db_pool.runInteraction(
"paginate_room_events", "paginate_room_events",
self._paginate_room_events_txn, self._paginate_room_events_txn,
room_id, room_id,
from_key, parsed_from_key,
to_key, parsed_to_key,
direction, direction,
limit, limit,
event_filter, event_filter,

View file

@ -18,7 +18,7 @@ import re
import string import string
import sys import sys
from collections import namedtuple from collections import namedtuple
from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar
import attr import attr
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -362,22 +362,79 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
return username.decode("ascii") return username.decode("ascii")
class StreamToken( @attr.s(frozen=True, slots=True)
namedtuple( class RoomStreamToken:
"Token", """Tokens are positions between events. The token "s1" comes after event 1.
(
"room_key", s0 s1
"presence_key", | |
"typing_key", [0] V [1] V [2]
"receipt_key",
"account_data_key", Tokens can either be a point in the live event stream or a cursor going
"push_rules_key", through historic events.
"to_device_key",
"device_list_key", When traversing the live event stream events are ordered by when they
"groups_key", arrived at the homeserver.
),
When traversing historic events the events are ordered by their depth in
the event graph "topological_ordering" and then by when they arrived at the
homeserver "stream_ordering".
Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after.
"""
topological = attr.ib(
type=Optional[int],
validator=attr.validators.optional(attr.validators.instance_of(int)),
) )
): stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
@classmethod
def parse(cls, string: str) -> "RoomStreamToken":
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
if string[0] == "t":
parts = string[1:].split("-", 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@classmethod
def parse_stream_token(cls, string: str) -> "RoomStreamToken":
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
def as_tuple(self) -> Tuple[Optional[int], int]:
return (self.topological, self.stream)
def __str__(self) -> str:
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)
@attr.s(slots=True, frozen=True)
class StreamToken:
room_key = attr.ib(type=str)
presence_key = attr.ib(type=int)
typing_key = attr.ib(type=int)
receipt_key = attr.ib(type=int)
account_data_key = attr.ib(type=int)
push_rules_key = attr.ib(type=int)
to_device_key = attr.ib(type=int)
device_list_key = attr.ib(type=int)
groups_key = attr.ib(type=int)
_SEPARATOR = "_" _SEPARATOR = "_"
START = None # type: StreamToken START = None # type: StreamToken
@ -385,15 +442,15 @@ class StreamToken(
def from_string(cls, string): def from_string(cls, string):
try: try:
keys = string.split(cls._SEPARATOR) keys = string.split(cls._SEPARATOR)
while len(keys) < len(cls._fields): while len(keys) < len(attr.fields(cls)):
# i.e. old token from before receipt_key # i.e. old token from before receipt_key
keys.append("0") keys.append("0")
return cls(*keys) return cls(keys[0], *(int(k) for k in keys[1:]))
except Exception: except Exception:
raise SynapseError(400, "Invalid Token") raise SynapseError(400, "Invalid Token")
def to_string(self): def to_string(self):
return self._SEPARATOR.join([str(k) for k in self]) return self._SEPARATOR.join([str(k) for k in attr.astuple(self)])
@property @property
def room_stream_id(self): def room_stream_id(self):
@ -435,63 +492,10 @@ class StreamToken(
return self return self
def copy_and_replace(self, key, new_value): def copy_and_replace(self, key, new_value):
return self._replace(**{key: new_value}) return attr.evolve(self, **{key: new_value})
StreamToken.START = StreamToken(*(["s0"] + ["0"] * (len(StreamToken._fields) - 1))) StreamToken.START = StreamToken.from_string("s0_0")
class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
"""Tokens are positions between events. The token "s1" comes after event 1.
s0 s1
| |
[0] V [1] V [2]
Tokens can either be a point in the live event stream or a cursor going
through historic events.
When traversing the live event stream events are ordered by when they
arrived at the homeserver.
When traversing historic events the events are ordered by their depth in
the event graph "topological_ordering" and then by when they arrived at the
homeserver "stream_ordering".
Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after.
"""
__slots__ = [] # type: list
@classmethod
def parse(cls, string):
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
if string[0] == "t":
parts = string[1:].split("-", 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@classmethod
def parse_stream_token(cls, string):
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
def __str__(self):
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)
class ThirdPartyInstanceID( class ThirdPartyInstanceID(