mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-25 19:15:51 +03:00
Sign outgoing PDUs.
This commit is contained in:
parent
1c445f88f6
commit
66104da10c
9 changed files with 62 additions and 24 deletions
|
@ -15,6 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from synapse.federation.units import Pdu
|
||||||
from synapse.api.events.utils import prune_pdu
|
from synapse.api.events.utils import prune_pdu
|
||||||
from syutil.jsonutil import encode_canonical_json
|
from syutil.jsonutil import encode_canonical_json
|
||||||
from syutil.base64util import encode_base64, decode_base64
|
from syutil.base64util import encode_base64, decode_base64
|
||||||
|
@ -25,8 +26,7 @@ import hashlib
|
||||||
|
|
||||||
def hash_event_pdu(pdu, hash_algortithm=hashlib.sha256):
|
def hash_event_pdu(pdu, hash_algortithm=hashlib.sha256):
|
||||||
hashed = _compute_hash(pdu, hash_algortithm)
|
hashed = _compute_hash(pdu, hash_algortithm)
|
||||||
hashes[hashed.name] = encode_base64(hashed.digest())
|
pdu.hashes[hashed.name] = encode_base64(hashed.digest())
|
||||||
pdu.hashes = hashes
|
|
||||||
return pdu
|
return pdu
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .units import Pdu
|
from .units import Pdu
|
||||||
|
from synapse.crypto.event_signing import hash_event_pdu, sign_event_pdu
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
@ -33,6 +34,7 @@ def encode_event_id(pdu_id, origin):
|
||||||
class PduCodec(object):
|
class PduCodec(object):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
self.signing_key = hs.config.signing_key[0]
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.event_factory = hs.get_event_factory()
|
self.event_factory = hs.get_event_factory()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
@ -99,4 +101,6 @@ class PduCodec(object):
|
||||||
if "ts" not in kwargs:
|
if "ts" not in kwargs:
|
||||||
kwargs["ts"] = int(self.clock.time_msec())
|
kwargs["ts"] = int(self.clock.time_msec())
|
||||||
|
|
||||||
return Pdu(**kwargs)
|
pdu = Pdu(**kwargs)
|
||||||
|
pdu = hash_event_pdu(pdu)
|
||||||
|
return sign_event_pdu(pdu, self.server_name, self.signing_key)
|
||||||
|
|
|
@ -42,6 +42,7 @@ from .transactions import TransactionStore
|
||||||
from .keys import KeyStore
|
from .keys import KeyStore
|
||||||
from .signatures import SignatureStore
|
from .signatures import SignatureStore
|
||||||
|
|
||||||
|
from syutil.base64util import decode_base64
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
@ -168,11 +169,11 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
txn, pdu.pdu_id, pdu.origin, hash_alg, hash_bytes,
|
txn, pdu.pdu_id, pdu.origin, hash_alg, hash_bytes,
|
||||||
)
|
)
|
||||||
|
|
||||||
signatures = pdu.sigatures.get(pdu.orgin, {})
|
signatures = pdu.signatures.get(pdu.origin, {})
|
||||||
|
|
||||||
for key_id, signature_base64 in signatures:
|
for key_id, signature_base64 in signatures.items():
|
||||||
signature_bytes = decode_base64(signature_base64)
|
signature_bytes = decode_base64(signature_base64)
|
||||||
self.store_pdu_origin_signatures_txn(
|
self._store_pdu_origin_signature_txn(
|
||||||
txn, pdu.pdu_id, pdu.origin, key_id, signature_bytes,
|
txn, pdu.pdu_id, pdu.origin, key_id, signature_bytes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -47,7 +47,7 @@ class SignatureStore(SQLBaseStore):
|
||||||
algorithm (str): Hashing algorithm.
|
algorithm (str): Hashing algorithm.
|
||||||
hash_bytes (bytes): Hash function output bytes.
|
hash_bytes (bytes): Hash function output bytes.
|
||||||
"""
|
"""
|
||||||
self._simple_insert_txn(self, txn, "pdu_hashes", {
|
self._simple_insert_txn(txn, "pdu_hashes", {
|
||||||
"pdu_id": pdu_id,
|
"pdu_id": pdu_id,
|
||||||
"origin": origin,
|
"origin": origin,
|
||||||
"algorithm": algorithm,
|
"algorithm": algorithm,
|
||||||
|
@ -66,7 +66,7 @@ class SignatureStore(SQLBaseStore):
|
||||||
query = (
|
query = (
|
||||||
"SELECT key_id, signature"
|
"SELECT key_id, signature"
|
||||||
" FROM pdu_origin_signatures"
|
" FROM pdu_origin_signatures"
|
||||||
" WHERE WHERE pdu_id = ? and origin = ?"
|
" WHERE pdu_id = ? and origin = ?"
|
||||||
)
|
)
|
||||||
txn.execute(query, (pdu_id, origin))
|
txn.execute(query, (pdu_id, origin))
|
||||||
return dict(txn.fetchall())
|
return dict(txn.fetchall())
|
||||||
|
@ -81,7 +81,7 @@ class SignatureStore(SQLBaseStore):
|
||||||
key_id (str): Id for the signing key.
|
key_id (str): Id for the signing key.
|
||||||
signature (bytes): The signature.
|
signature (bytes): The signature.
|
||||||
"""
|
"""
|
||||||
self._simple_insert_txn(self, txn, "pdu_origin_signatures", {
|
self._simple_insert_txn(txn, "pdu_origin_signatures", {
|
||||||
"pdu_id": pdu_id,
|
"pdu_id": pdu_id,
|
||||||
"origin": origin,
|
"origin": origin,
|
||||||
"key_id": key_id,
|
"key_id": key_id,
|
||||||
|
|
|
@ -23,14 +23,21 @@ from synapse.federation.units import Pdu
|
||||||
|
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock, NonCallableMock
|
||||||
|
|
||||||
|
from ..utils import MockKey
|
||||||
|
|
||||||
|
|
||||||
class PduCodecTestCase(unittest.TestCase):
|
class PduCodecTestCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.hs = HomeServer("blargle.net")
|
self.mock_config = NonCallableMock()
|
||||||
self.event_factory = self.hs.get_event_factory()
|
self.mock_config.signing_key = [MockKey()]
|
||||||
|
|
||||||
|
self.hs = HomeServer(
|
||||||
|
"blargle.net",
|
||||||
|
config=self.mock_config,
|
||||||
|
)
|
||||||
|
self.event_factory = self.hs.get_event_factory()
|
||||||
self.codec = PduCodec(self.hs)
|
self.codec = PduCodec(self.hs)
|
||||||
|
|
||||||
def test_decode_event_id(self):
|
def test_decode_event_id(self):
|
||||||
|
|
|
@ -28,7 +28,7 @@ from synapse.server import HomeServer
|
||||||
# python imports
|
# python imports
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from ..utils import MockHttpResource, MemoryDataStore
|
from ..utils import MockHttpResource, MemoryDataStore, MockKey
|
||||||
from .utils import RestTestCase
|
from .utils import RestTestCase
|
||||||
|
|
||||||
from mock import Mock, NonCallableMock
|
from mock import Mock, NonCallableMock
|
||||||
|
@ -122,6 +122,9 @@ class EventStreamPermissionsTestCase(RestTestCase):
|
||||||
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
||||||
persistence_service.get_latest_pdus_in_context.return_value = []
|
persistence_service.get_latest_pdus_in_context.return_value = []
|
||||||
|
|
||||||
|
self.mock_config = NonCallableMock()
|
||||||
|
self.mock_config.signing_key = [MockKey()]
|
||||||
|
|
||||||
hs = HomeServer(
|
hs = HomeServer(
|
||||||
"test",
|
"test",
|
||||||
db_pool=None,
|
db_pool=None,
|
||||||
|
@ -139,7 +142,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
|
||||||
ratelimiter=NonCallableMock(spec_set=[
|
ratelimiter=NonCallableMock(spec_set=[
|
||||||
"send_message",
|
"send_message",
|
||||||
]),
|
]),
|
||||||
config=NonCallableMock(),
|
config=self.mock_config,
|
||||||
)
|
)
|
||||||
self.ratelimiter = hs.get_ratelimiter()
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
self.ratelimiter.send_message.return_value = (True, 0)
|
||||||
|
|
|
@ -18,9 +18,9 @@
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock, NonCallableMock
|
||||||
|
|
||||||
from ..utils import MockHttpResource
|
from ..utils import MockHttpResource, MockKey
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError, AuthError
|
from synapse.api.errors import SynapseError, AuthError
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -41,6 +41,9 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
"set_avatar_url",
|
"set_avatar_url",
|
||||||
])
|
])
|
||||||
|
|
||||||
|
self.mock_config = NonCallableMock()
|
||||||
|
self.mock_config.signing_key = [MockKey()]
|
||||||
|
|
||||||
hs = HomeServer("test",
|
hs = HomeServer("test",
|
||||||
db_pool=None,
|
db_pool=None,
|
||||||
http_client=None,
|
http_client=None,
|
||||||
|
@ -48,6 +51,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
federation=Mock(),
|
federation=Mock(),
|
||||||
replication_layer=Mock(),
|
replication_layer=Mock(),
|
||||||
datastore=None,
|
datastore=None,
|
||||||
|
config=self.mock_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_user_by_req(request=None):
|
def _get_user_by_req(request=None):
|
||||||
|
|
|
@ -27,7 +27,7 @@ from synapse.server import HomeServer
|
||||||
import json
|
import json
|
||||||
import urllib
|
import urllib
|
||||||
|
|
||||||
from ..utils import MockHttpResource, MemoryDataStore
|
from ..utils import MockHttpResource, MemoryDataStore, MockKey
|
||||||
from .utils import RestTestCase
|
from .utils import RestTestCase
|
||||||
|
|
||||||
from mock import Mock, NonCallableMock
|
from mock import Mock, NonCallableMock
|
||||||
|
@ -50,6 +50,9 @@ class RoomPermissionsTestCase(RestTestCase):
|
||||||
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
||||||
persistence_service.get_latest_pdus_in_context.return_value = []
|
persistence_service.get_latest_pdus_in_context.return_value = []
|
||||||
|
|
||||||
|
self.mock_config = NonCallableMock()
|
||||||
|
self.mock_config.signing_key = [MockKey()]
|
||||||
|
|
||||||
hs = HomeServer(
|
hs = HomeServer(
|
||||||
"red",
|
"red",
|
||||||
db_pool=None,
|
db_pool=None,
|
||||||
|
@ -61,7 +64,7 @@ class RoomPermissionsTestCase(RestTestCase):
|
||||||
ratelimiter=NonCallableMock(spec_set=[
|
ratelimiter=NonCallableMock(spec_set=[
|
||||||
"send_message",
|
"send_message",
|
||||||
]),
|
]),
|
||||||
config=NonCallableMock(),
|
config=self.mock_config,
|
||||||
)
|
)
|
||||||
self.ratelimiter = hs.get_ratelimiter()
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
self.ratelimiter.send_message.return_value = (True, 0)
|
||||||
|
@ -408,6 +411,9 @@ class RoomsMemberListTestCase(RestTestCase):
|
||||||
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
||||||
persistence_service.get_latest_pdus_in_context.return_value = []
|
persistence_service.get_latest_pdus_in_context.return_value = []
|
||||||
|
|
||||||
|
self.mock_config = NonCallableMock()
|
||||||
|
self.mock_config.signing_key = [MockKey()]
|
||||||
|
|
||||||
hs = HomeServer(
|
hs = HomeServer(
|
||||||
"red",
|
"red",
|
||||||
db_pool=None,
|
db_pool=None,
|
||||||
|
@ -419,7 +425,7 @@ class RoomsMemberListTestCase(RestTestCase):
|
||||||
ratelimiter=NonCallableMock(spec_set=[
|
ratelimiter=NonCallableMock(spec_set=[
|
||||||
"send_message",
|
"send_message",
|
||||||
]),
|
]),
|
||||||
config=NonCallableMock(),
|
config=self.mock_config,
|
||||||
)
|
)
|
||||||
self.ratelimiter = hs.get_ratelimiter()
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
self.ratelimiter.send_message.return_value = (True, 0)
|
||||||
|
@ -497,6 +503,9 @@ class RoomsCreateTestCase(RestTestCase):
|
||||||
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
||||||
persistence_service.get_latest_pdus_in_context.return_value = []
|
persistence_service.get_latest_pdus_in_context.return_value = []
|
||||||
|
|
||||||
|
self.mock_config = NonCallableMock()
|
||||||
|
self.mock_config.signing_key = [MockKey()]
|
||||||
|
|
||||||
hs = HomeServer(
|
hs = HomeServer(
|
||||||
"red",
|
"red",
|
||||||
db_pool=None,
|
db_pool=None,
|
||||||
|
@ -508,7 +517,7 @@ class RoomsCreateTestCase(RestTestCase):
|
||||||
ratelimiter=NonCallableMock(spec_set=[
|
ratelimiter=NonCallableMock(spec_set=[
|
||||||
"send_message",
|
"send_message",
|
||||||
]),
|
]),
|
||||||
config=NonCallableMock(),
|
config=self.mock_config,
|
||||||
)
|
)
|
||||||
self.ratelimiter = hs.get_ratelimiter()
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
self.ratelimiter.send_message.return_value = (True, 0)
|
||||||
|
@ -598,6 +607,9 @@ class RoomTopicTestCase(RestTestCase):
|
||||||
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
||||||
persistence_service.get_latest_pdus_in_context.return_value = []
|
persistence_service.get_latest_pdus_in_context.return_value = []
|
||||||
|
|
||||||
|
self.mock_config = NonCallableMock()
|
||||||
|
self.mock_config.signing_key = [MockKey()]
|
||||||
|
|
||||||
hs = HomeServer(
|
hs = HomeServer(
|
||||||
"red",
|
"red",
|
||||||
db_pool=None,
|
db_pool=None,
|
||||||
|
@ -609,7 +621,7 @@ class RoomTopicTestCase(RestTestCase):
|
||||||
ratelimiter=NonCallableMock(spec_set=[
|
ratelimiter=NonCallableMock(spec_set=[
|
||||||
"send_message",
|
"send_message",
|
||||||
]),
|
]),
|
||||||
config=NonCallableMock(),
|
config=self.mock_config,
|
||||||
)
|
)
|
||||||
self.ratelimiter = hs.get_ratelimiter()
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
self.ratelimiter.send_message.return_value = (True, 0)
|
||||||
|
@ -712,6 +724,9 @@ class RoomMemberStateTestCase(RestTestCase):
|
||||||
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
||||||
persistence_service.get_latest_pdus_in_context.return_value = []
|
persistence_service.get_latest_pdus_in_context.return_value = []
|
||||||
|
|
||||||
|
self.mock_config = NonCallableMock()
|
||||||
|
self.mock_config.signing_key = [MockKey()]
|
||||||
|
|
||||||
hs = HomeServer(
|
hs = HomeServer(
|
||||||
"red",
|
"red",
|
||||||
db_pool=None,
|
db_pool=None,
|
||||||
|
@ -723,7 +738,7 @@ class RoomMemberStateTestCase(RestTestCase):
|
||||||
ratelimiter=NonCallableMock(spec_set=[
|
ratelimiter=NonCallableMock(spec_set=[
|
||||||
"send_message",
|
"send_message",
|
||||||
]),
|
]),
|
||||||
config=NonCallableMock(),
|
config=self.mock_config,
|
||||||
)
|
)
|
||||||
self.ratelimiter = hs.get_ratelimiter()
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
self.ratelimiter.send_message.return_value = (True, 0)
|
||||||
|
@ -853,6 +868,9 @@ class RoomMessagesTestCase(RestTestCase):
|
||||||
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
||||||
persistence_service.get_latest_pdus_in_context.return_value = []
|
persistence_service.get_latest_pdus_in_context.return_value = []
|
||||||
|
|
||||||
|
self.mock_config = NonCallableMock()
|
||||||
|
self.mock_config.signing_key = [MockKey()]
|
||||||
|
|
||||||
hs = HomeServer(
|
hs = HomeServer(
|
||||||
"red",
|
"red",
|
||||||
db_pool=None,
|
db_pool=None,
|
||||||
|
@ -864,7 +882,7 @@ class RoomMessagesTestCase(RestTestCase):
|
||||||
ratelimiter=NonCallableMock(spec_set=[
|
ratelimiter=NonCallableMock(spec_set=[
|
||||||
"send_message",
|
"send_message",
|
||||||
]),
|
]),
|
||||||
config=NonCallableMock(),
|
config=self.mock_config,
|
||||||
)
|
)
|
||||||
self.ratelimiter = hs.get_ratelimiter()
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
self.ratelimiter.send_message.return_value = (True, 0)
|
||||||
|
|
|
@ -118,13 +118,14 @@ class MockHttpResource(HttpServer):
|
||||||
class MockKey(object):
|
class MockKey(object):
|
||||||
alg = "mock_alg"
|
alg = "mock_alg"
|
||||||
version = "mock_version"
|
version = "mock_version"
|
||||||
|
signature = b"\x9a\x87$"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def verify_key(self):
|
def verify_key(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def sign(self, message):
|
def sign(self, message):
|
||||||
return b"\x9a\x87$"
|
return self
|
||||||
|
|
||||||
def verify(self, message, sig):
|
def verify(self, message, sig):
|
||||||
assert sig == b"\x9a\x87$"
|
assert sig == b"\x9a\x87$"
|
||||||
|
|
Loading…
Reference in a new issue