persist hashes and origin signatures for PDUs

This commit is contained in:
Mark Haines 2014-10-15 17:09:04 +01:00
parent 27d0c1ecc2
commit 1c445f88f6
7 changed files with 135 additions and 15 deletions

View file

@ -27,7 +27,14 @@ def prune_event(event):
the user has specified, but we do want to keep necessary information like the user has specified, but we do want to keep necessary information like
type, state_key etc. type, state_key etc.
""" """
return _prune_event_or_pdu(event.type, event)
def prune_pdu(pdu):
"""Removes keys that contain unrestricted and non-essential data from a PDU
"""
return _prune_event_or_pdu(pdu.pdu_type, pdu)
def _prune_event_or_pdu(event_type, event):
# Remove all extraneous fields. # Remove all extraneous fields.
event.unrecognized_keys = {} event.unrecognized_keys = {}
@ -38,25 +45,25 @@ def prune_event(event):
if field in event.content: if field in event.content:
new_content[field] = event.content[field] new_content[field] = event.content[field]
if event.type == RoomMemberEvent.TYPE: if event_type == RoomMemberEvent.TYPE:
add_fields("membership") add_fields("membership")
elif event.type == RoomCreateEvent.TYPE: elif event_type == RoomCreateEvent.TYPE:
add_fields("creator") add_fields("creator")
elif event.type == RoomJoinRulesEvent.TYPE: elif event_type == RoomJoinRulesEvent.TYPE:
add_fields("join_rule") add_fields("join_rule")
elif event.type == RoomPowerLevelsEvent.TYPE: elif event_type == RoomPowerLevelsEvent.TYPE:
# TODO: Actually check these are valid user_ids etc. # TODO: Actually check these are valid user_ids etc.
add_fields("default") add_fields("default")
for k, v in event.content.items(): for k, v in event.content.items():
if k.startswith("@") and isinstance(v, (int, long)): if k.startswith("@") and isinstance(v, (int, long)):
new_content[k] = v new_content[k] = v
elif event.type == RoomAddStateLevelEvent.TYPE: elif event_type == RoomAddStateLevelEvent.TYPE:
add_fields("level") add_fields("level")
elif event.type == RoomSendEventLevelEvent.TYPE: elif event_type == RoomSendEventLevelEvent.TYPE:
add_fields("level") add_fields("level")
elif event.type == RoomOpsPowerLevelsEvent.TYPE: elif event_type == RoomOpsPowerLevelsEvent.TYPE:
add_fields("kick_level", "ban_level", "redact_level") add_fields("kick_level", "ban_level", "redact_level")
elif event.type == RoomAliasesEvent.TYPE: elif event_type == RoomAliasesEvent.TYPE:
add_fields("aliases") add_fields("aliases")
event.content = new_content event.content = new_content

View file

@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.events.utils import prune_pdu
from syutil.jsonutil import encode_canonical_json
from syutil.base64util import encode_base64, decode_base64
from syutil.crypto.jsonsign import sign_json, verify_signed_json
import hashlib
def hash_event_pdu(pdu, hash_algortithm=hashlib.sha256):
hashed = _compute_hash(pdu, hash_algortithm)
hashes[hashed.name] = encode_base64(hashed.digest())
pdu.hashes = hashes
return pdu
def check_event_pdu_hash(pdu, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents"""
computed_hash = _compute_hash(pdu, hash_algortithm)
if computed_hash.name not in pdu.hashes:
raise Exception("Algorithm %s not in hashes %s" % (
computed_hash.name, list(pdu.hashes)
))
message_hash_base64 = hashes[computed_hash.name]
try:
message_hash_bytes = decode_base64(message_hash_base64)
except:
raise Exception("Invalid base64: %s" % (message_hash_base64,))
return message_hash_bytes == computed_hash.digest()
def _compute_hash(pdu, hash_algorithm):
pdu_json = pdu.get_dict()
pdu_json.pop("meta", None)
pdu_json.pop("signatures", None)
hashes = pdu_json.pop("hashes", {})
pdu_json_bytes = encode_canonical_json(pdu_json)
return hash_algorithm(pdu_json_bytes)
def sign_event_pdu(pdu, signature_name, signing_key):
tmp_pdu = Pdu(**pdu.get_dict())
tmp_pdu = prune_pdu(tmp_pdu)
pdu_json = tmp_pdu.get_dict()
pdu_jdon = sign_json(pdu_json, signature_name, signing_key)
pdu.signatures = pdu_json["signatures"]
return pdu
def verify_signed_event_pdu(pdu, signature_name, verify_key):
tmp_pdu = Pdu(**pdu.get_dict())
tmp_pdu = prune_pdu(tmp_pdu)
pdu_json = tmp_pdu.get_dict()
verify_signed_json(pdu_json, signature_name, verify_key)

View file

@ -18,6 +18,7 @@ server protocol.
""" """
from synapse.util.jsonobject import JsonEncodedObject from synapse.util.jsonobject import JsonEncodedObject
from syutil.base64util import encode_base64
import logging import logging
import json import json
@ -63,6 +64,8 @@ class Pdu(JsonEncodedObject):
"depth", "depth",
"content", "content",
"outlier", "outlier",
"hashes",
"signatures",
"is_state", # Below this are keys valid only for State Pdus. "is_state", # Below this are keys valid only for State Pdus.
"state_key", "state_key",
"power_level", "power_level",
@ -91,7 +94,7 @@ class Pdu(JsonEncodedObject):
# just leaving it as a dict. (OR DO WE?!) # just leaving it as a dict. (OR DO WE?!)
def __init__(self, destinations=[], is_state=False, prev_pdus=[], def __init__(self, destinations=[], is_state=False, prev_pdus=[],
outlier=False, **kwargs): outlier=False, hashes={}, signatures={}, **kwargs):
if is_state: if is_state:
for required_key in ["state_key"]: for required_key in ["state_key"]:
if required_key not in kwargs: if required_key not in kwargs:
@ -102,6 +105,8 @@ class Pdu(JsonEncodedObject):
is_state=is_state, is_state=is_state,
prev_pdus=prev_pdus, prev_pdus=prev_pdus,
outlier=outlier, outlier=outlier,
hashes=hashes,
signatures=signatures,
**kwargs **kwargs
) )
@ -126,6 +131,16 @@ class Pdu(JsonEncodedObject):
if "unrecognized_keys" in d and d["unrecognized_keys"]: if "unrecognized_keys" in d and d["unrecognized_keys"]:
args.update(json.loads(d["unrecognized_keys"])) args.update(json.loads(d["unrecognized_keys"]))
hashes = {
alg: encode_base64(hsh)
for alg, hsh in pdu_tuple.hashes.items()
}
signatures = {
kid: encode_base64(sig)
for kid, sig in pdu_tuple.signatures.items()
}
return Pdu( return Pdu(
prev_pdus=pdu_tuple.prev_pdu_list, prev_pdus=pdu_tuple.prev_pdu_list,
**args **args

View file

@ -40,6 +40,8 @@ from .stream import StreamStore
from .pdu import StatePduStore, PduStore, PdusTable from .pdu import StatePduStore, PduStore, PdusTable
from .transactions import TransactionStore from .transactions import TransactionStore
from .keys import KeyStore from .keys import KeyStore
from .signatures import SignatureStore
import json import json
import logging import logging
@ -59,6 +61,7 @@ SCHEMAS = [
"room_aliases", "room_aliases",
"keys", "keys",
"redactions", "redactions",
"signatures",
] ]
@ -76,7 +79,7 @@ class _RollbackButIsFineException(Exception):
class DataStore(RoomMemberStore, RoomStore, class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, FeedbackStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
PresenceStore, PduStore, StatePduStore, TransactionStore, PresenceStore, PduStore, StatePduStore, TransactionStore,
DirectoryStore, KeyStore): DirectoryStore, KeyStore, SignatureStore):
def __init__(self, hs): def __init__(self, hs):
super(DataStore, self).__init__(hs) super(DataStore, self).__init__(hs)
@ -144,6 +147,8 @@ class DataStore(RoomMemberStore, RoomStore,
def _persist_event_pdu_txn(self, txn, pdu): def _persist_event_pdu_txn(self, txn, pdu):
cols = dict(pdu.__dict__) cols = dict(pdu.__dict__)
unrec_keys = dict(pdu.unrecognized_keys) unrec_keys = dict(pdu.unrecognized_keys)
del cols["hashes"]
del cols["signatures"]
del cols["content"] del cols["content"]
del cols["prev_pdus"] del cols["prev_pdus"]
cols["content_json"] = json.dumps(pdu.content) cols["content_json"] = json.dumps(pdu.content)
@ -157,6 +162,20 @@ class DataStore(RoomMemberStore, RoomStore,
logger.debug("Persisting: %s", repr(cols)) logger.debug("Persisting: %s", repr(cols))
for hash_alg, hash_base64 in pdu.hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_pdu_hash_txn(
txn, pdu.pdu_id, pdu.origin, hash_alg, hash_bytes,
)
signatures = pdu.sigatures.get(pdu.orgin, {})
for key_id, signature_base64 in signatures:
signature_bytes = decode_base64(signature_base64)
self.store_pdu_origin_signatures_txn(
txn, pdu.pdu_id, pdu.origin, key_id, signature_bytes,
)
if pdu.is_state: if pdu.is_state:
self._persist_state_txn(txn, pdu.prev_pdus, cols) self._persist_state_txn(txn, pdu.prev_pdus, cols)
else: else:

View file

@ -64,6 +64,11 @@ class PduStore(SQLBaseStore):
for r in PduEdgesTable.decode_results(txn.fetchall()) for r in PduEdgesTable.decode_results(txn.fetchall())
] ]
hashes = self._get_pdu_hashes_txn(txn, pdu_id, origin)
signatures = self._get_pdu_origin_signatures_txn(
txn, pdu_id, origin
)
query = ( query = (
"SELECT %(fields)s FROM %(pdus)s as p " "SELECT %(fields)s FROM %(pdus)s as p "
"LEFT JOIN %(state)s as s " "LEFT JOIN %(state)s as s "
@ -80,7 +85,9 @@ class PduStore(SQLBaseStore):
row = txn.fetchone() row = txn.fetchone()
if row: if row:
results.append(PduTuple(PduEntry(*row), edges)) results.append(PduTuple(
PduEntry(*row), edges, hashes, signatures
))
return results return results
@ -908,7 +915,7 @@ This does not include a prev_pdus key.
PduTuple = namedtuple( PduTuple = namedtuple(
"PduTuple", "PduTuple",
("pdu_entry", "prev_pdu_list") ("pdu_entry", "prev_pdu_list", "hashes", "signatures")
) )
""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent """ This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
the `prev_pdus` key of a PDU. the `prev_pdus` key of a PDU.

View file

@ -28,9 +28,9 @@ CREATE TABLE IF NOT EXISTS pdu_origin_signatures (
origin TEXT, origin TEXT,
key_id TEXT, key_id TEXT,
signature BLOB, signature BLOB,
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, algorithm) CONSTRAINT uniqueness UNIQUE (pdu_id, origin, key_id)
); );
CREATE INDEX IF NOT EXISTS pdu_origin_signatures_id ON pdu_origin_signatures ( CREATE INDEX IF NOT EXISTS pdu_origin_signatures_id ON pdu_origin_signatures (
pdu_id, origin, pdu_id, origin
); );

View file

@ -41,7 +41,7 @@ def make_pdu(prev_pdus=[], **kwargs):
} }
pdu_fields.update(kwargs) pdu_fields.update(kwargs)
return PduTuple(PduEntry(**pdu_fields), prev_pdus) return PduTuple(PduEntry(**pdu_fields), prev_pdus, {}, {})
class FederationTestCase(unittest.TestCase): class FederationTestCase(unittest.TestCase):
@ -183,6 +183,8 @@ class FederationTestCase(unittest.TestCase):
"is_state": False, "is_state": False,
"content": {"testing": "content here"}, "content": {"testing": "content here"},
"depth": 1, "depth": 1,
"hashes": {},
"signatures": {},
}, },
] ]
}, },