diff --git a/changelog.d/11312.misc b/changelog.d/11312.misc new file mode 100644 index 0000000000..86594a332d --- /dev/null +++ b/changelog.d/11312.misc @@ -0,0 +1 @@ +Add type hints to storage classes. diff --git a/mypy.ini b/mypy.ini index 48dfdfa0e0..3b7e1eb708 100644 --- a/mypy.ini +++ b/mypy.ini @@ -48,7 +48,6 @@ exclude = (?x) |synapse/storage/databases/main/room.py |synapse/storage/databases/main/roommember.py |synapse/storage/databases/main/search.py - |synapse/storage/databases/main/signatures.py |synapse/storage/databases/main/state.py |synapse/storage/databases/main/state_deltas.py |synapse/storage/databases/main/stats.py diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 4f409f31e1..eb39e0ae32 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -128,14 +128,12 @@ class EventBuilder: ) format_version = self.room_version.event_format + # The types of auth/prev events changes between event versions. + prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] + auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] if format_version == EventFormatVersions.V1: - # The types of auth/prev events changes between event versions. - auth_events: Union[ - List[str], List[Tuple[str, Dict[str, str]]] - ] = await self._store.add_event_hashes(auth_event_ids) - prev_events: Union[ - List[str], List[Tuple[str, Dict[str, str]]] - ] = await self._store.add_event_hashes(prev_event_ids) + auth_events = await self._store.add_event_hashes(auth_event_ids) + prev_events = await self._store.add_event_hashes(prev_event_ids) else: auth_events = auth_event_ids prev_events = prev_event_ids diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index ab2159c2d3..3201623fe4 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -63,12 +63,12 @@ class SignatureWorkerStore(SQLBaseStore): A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash. """ hashes = await self.get_event_reference_hashes(event_ids) - hashes = { + encoded_hashes = { e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} for e_id, h in hashes.items() } - return list(hashes.items()) + return list(encoded_hashes.items()) def _get_event_reference_hashes_txn( self, txn: Cursor, event_id: str