Sign outgoing PDUs.

This commit is contained in:
Mark Haines 2014-10-16 00:09:48 +01:00
parent 1c445f88f6
commit 66104da10c
9 changed files with 62 additions and 24 deletions

View file

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
from synapse.federation.units import Pdu
from synapse.api.events.utils import prune_pdu from synapse.api.events.utils import prune_pdu
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
from syutil.base64util import encode_base64, decode_base64 from syutil.base64util import encode_base64, decode_base64
@ -25,8 +26,7 @@ import hashlib
def hash_event_pdu(pdu, hash_algortithm=hashlib.sha256): def hash_event_pdu(pdu, hash_algortithm=hashlib.sha256):
hashed = _compute_hash(pdu, hash_algortithm) hashed = _compute_hash(pdu, hash_algortithm)
hashes[hashed.name] = encode_base64(hashed.digest()) pdu.hashes[hashed.name] = encode_base64(hashed.digest())
pdu.hashes = hashes
return pdu return pdu

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from .units import Pdu from .units import Pdu
from synapse.crypto.event_signing import hash_event_pdu, sign_event_pdu
import copy import copy
@ -33,6 +34,7 @@ def encode_event_id(pdu_id, origin):
class PduCodec(object): class PduCodec(object):
def __init__(self, hs): def __init__(self, hs):
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname self.server_name = hs.hostname
self.event_factory = hs.get_event_factory() self.event_factory = hs.get_event_factory()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -99,4 +101,6 @@ class PduCodec(object):
if "ts" not in kwargs: if "ts" not in kwargs:
kwargs["ts"] = int(self.clock.time_msec()) kwargs["ts"] = int(self.clock.time_msec())
return Pdu(**kwargs) pdu = Pdu(**kwargs)
pdu = hash_event_pdu(pdu)
return sign_event_pdu(pdu, self.server_name, self.signing_key)

View file

@ -42,6 +42,7 @@ from .transactions import TransactionStore
from .keys import KeyStore from .keys import KeyStore
from .signatures import SignatureStore from .signatures import SignatureStore
from syutil.base64util import decode_base64
import json import json
import logging import logging
@ -168,11 +169,11 @@ class DataStore(RoomMemberStore, RoomStore,
txn, pdu.pdu_id, pdu.origin, hash_alg, hash_bytes, txn, pdu.pdu_id, pdu.origin, hash_alg, hash_bytes,
) )
signatures = pdu.sigatures.get(pdu.orgin, {}) signatures = pdu.signatures.get(pdu.origin, {})
for key_id, signature_base64 in signatures: for key_id, signature_base64 in signatures.items():
signature_bytes = decode_base64(signature_base64) signature_bytes = decode_base64(signature_base64)
self.store_pdu_origin_signatures_txn( self._store_pdu_origin_signature_txn(
txn, pdu.pdu_id, pdu.origin, key_id, signature_bytes, txn, pdu.pdu_id, pdu.origin, key_id, signature_bytes,
) )

View file

@ -47,7 +47,7 @@ class SignatureStore(SQLBaseStore):
algorithm (str): Hashing algorithm. algorithm (str): Hashing algorithm.
hash_bytes (bytes): Hash function output bytes. hash_bytes (bytes): Hash function output bytes.
""" """
self._simple_insert_txn(self, txn, "pdu_hashes", { self._simple_insert_txn(txn, "pdu_hashes", {
"pdu_id": pdu_id, "pdu_id": pdu_id,
"origin": origin, "origin": origin,
"algorithm": algorithm, "algorithm": algorithm,
@ -66,7 +66,7 @@ class SignatureStore(SQLBaseStore):
query = ( query = (
"SELECT key_id, signature" "SELECT key_id, signature"
" FROM pdu_origin_signatures" " FROM pdu_origin_signatures"
" WHERE WHERE pdu_id = ? and origin = ?" " WHERE pdu_id = ? and origin = ?"
) )
txn.execute(query, (pdu_id, origin)) txn.execute(query, (pdu_id, origin))
return dict(txn.fetchall()) return dict(txn.fetchall())
@ -81,7 +81,7 @@ class SignatureStore(SQLBaseStore):
key_id (str): Id for the signing key. key_id (str): Id for the signing key.
signature (bytes): The signature. signature (bytes): The signature.
""" """
self._simple_insert_txn(self, txn, "pdu_origin_signatures", { self._simple_insert_txn(txn, "pdu_origin_signatures", {
"pdu_id": pdu_id, "pdu_id": pdu_id,
"origin": origin, "origin": origin,
"key_id": key_id, "key_id": key_id,

View file

@ -23,14 +23,21 @@ from synapse.federation.units import Pdu
from synapse.server import HomeServer from synapse.server import HomeServer
from mock import Mock from mock import Mock, NonCallableMock
from ..utils import MockKey
class PduCodecTestCase(unittest.TestCase): class PduCodecTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.hs = HomeServer("blargle.net") self.mock_config = NonCallableMock()
self.event_factory = self.hs.get_event_factory() self.mock_config.signing_key = [MockKey()]
self.hs = HomeServer(
"blargle.net",
config=self.mock_config,
)
self.event_factory = self.hs.get_event_factory()
self.codec = PduCodec(self.hs) self.codec = PduCodec(self.hs)
def test_decode_event_id(self): def test_decode_event_id(self):

View file

@ -28,7 +28,7 @@ from synapse.server import HomeServer
# python imports # python imports
import json import json
from ..utils import MockHttpResource, MemoryDataStore from ..utils import MockHttpResource, MemoryDataStore, MockKey
from .utils import RestTestCase from .utils import RestTestCase
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock
@ -122,6 +122,9 @@ class EventStreamPermissionsTestCase(RestTestCase):
persistence_service = Mock(spec=["get_latest_pdus_in_context"]) persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = [] persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer( hs = HomeServer(
"test", "test",
db_pool=None, db_pool=None,
@ -139,7 +142,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),
config=NonCallableMock(), config=self.mock_config,
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)

View file

@ -18,9 +18,9 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from mock import Mock from mock import Mock, NonCallableMock
from ..utils import MockHttpResource from ..utils import MockHttpResource, MockKey
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.server import HomeServer from synapse.server import HomeServer
@ -41,6 +41,9 @@ class ProfileTestCase(unittest.TestCase):
"set_avatar_url", "set_avatar_url",
]) ])
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer("test", hs = HomeServer("test",
db_pool=None, db_pool=None,
http_client=None, http_client=None,
@ -48,6 +51,7 @@ class ProfileTestCase(unittest.TestCase):
federation=Mock(), federation=Mock(),
replication_layer=Mock(), replication_layer=Mock(),
datastore=None, datastore=None,
config=self.mock_config,
) )
def _get_user_by_req(request=None): def _get_user_by_req(request=None):

View file

@ -27,7 +27,7 @@ from synapse.server import HomeServer
import json import json
import urllib import urllib
from ..utils import MockHttpResource, MemoryDataStore from ..utils import MockHttpResource, MemoryDataStore, MockKey
from .utils import RestTestCase from .utils import RestTestCase
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock
@ -50,6 +50,9 @@ class RoomPermissionsTestCase(RestTestCase):
persistence_service = Mock(spec=["get_latest_pdus_in_context"]) persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = [] persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer( hs = HomeServer(
"red", "red",
db_pool=None, db_pool=None,
@ -61,7 +64,7 @@ class RoomPermissionsTestCase(RestTestCase):
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),
config=NonCallableMock(), config=self.mock_config,
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
@ -408,6 +411,9 @@ class RoomsMemberListTestCase(RestTestCase):
persistence_service = Mock(spec=["get_latest_pdus_in_context"]) persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = [] persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer( hs = HomeServer(
"red", "red",
db_pool=None, db_pool=None,
@ -419,7 +425,7 @@ class RoomsMemberListTestCase(RestTestCase):
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),
config=NonCallableMock(), config=self.mock_config,
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
@ -497,6 +503,9 @@ class RoomsCreateTestCase(RestTestCase):
persistence_service = Mock(spec=["get_latest_pdus_in_context"]) persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = [] persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer( hs = HomeServer(
"red", "red",
db_pool=None, db_pool=None,
@ -508,7 +517,7 @@ class RoomsCreateTestCase(RestTestCase):
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),
config=NonCallableMock(), config=self.mock_config,
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
@ -598,6 +607,9 @@ class RoomTopicTestCase(RestTestCase):
persistence_service = Mock(spec=["get_latest_pdus_in_context"]) persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = [] persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer( hs = HomeServer(
"red", "red",
db_pool=None, db_pool=None,
@ -609,7 +621,7 @@ class RoomTopicTestCase(RestTestCase):
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),
config=NonCallableMock(), config=self.mock_config,
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
@ -712,6 +724,9 @@ class RoomMemberStateTestCase(RestTestCase):
persistence_service = Mock(spec=["get_latest_pdus_in_context"]) persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = [] persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer( hs = HomeServer(
"red", "red",
db_pool=None, db_pool=None,
@ -723,7 +738,7 @@ class RoomMemberStateTestCase(RestTestCase):
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),
config=NonCallableMock(), config=self.mock_config,
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
@ -853,6 +868,9 @@ class RoomMessagesTestCase(RestTestCase):
persistence_service = Mock(spec=["get_latest_pdus_in_context"]) persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = [] persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer( hs = HomeServer(
"red", "red",
db_pool=None, db_pool=None,
@ -864,7 +882,7 @@ class RoomMessagesTestCase(RestTestCase):
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),
config=NonCallableMock(), config=self.mock_config,
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)

View file

@ -118,13 +118,14 @@ class MockHttpResource(HttpServer):
class MockKey(object): class MockKey(object):
alg = "mock_alg" alg = "mock_alg"
version = "mock_version" version = "mock_version"
signature = b"\x9a\x87$"
@property @property
def verify_key(self): def verify_key(self):
return self return self
def sign(self, message): def sign(self, message):
return b"\x9a\x87$" return self
def verify(self, message, sig): def verify(self, message, sig):
assert sig == b"\x9a\x87$" assert sig == b"\x9a\x87$"