Begin implementing state groups.

This commit is contained in:
Erik Johnston 2014-10-14 16:59:51 +01:00
parent 636a0dbde7
commit 5fefc12d1e
3 changed files with 123 additions and 3 deletions

View file

@ -35,7 +35,7 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
class StateHandler(object): class StateHandler(object):
""" Repsonsible for doing state conflict resolution. """ Responsible for doing state conflict resolution.
""" """
def __init__(self, hs): def __init__(self, hs):
@ -50,7 +50,7 @@ class StateHandler(object):
to update the state and b) works out what the prev_state should be. to update the state and b) works out what the prev_state should be.
Returns: Returns:
Deferred: Resolved with a boolean indicating if we succesfully Deferred: Resolved with a boolean indicating if we successfully
updated the state. updated the state.
Raised: Raised:
@ -83,6 +83,8 @@ class StateHandler(object):
current_state.pdu_id, current_state.origin current_state.pdu_id, current_state.origin
) )
yield self.update_state_groups(event)
# TODO check current_state to see if the min power level is less # TODO check current_state to see if the min power level is less
# than the power level of the user # than the power level of the user
# power_level = self._get_power_level_for_event(event) # power_level = self._get_power_level_for_event(event)
@ -128,6 +130,87 @@ class StateHandler(object):
defer.returnValue(is_new) defer.returnValue(is_new)
@defer.inlineCallbacks
def update_state_groups(self, event):
state_groups = yield self.store.get_state_groups(
event.prev_events
)
if len(state_groups) == 1 and not hasattr(event, "state_key"):
event.state_group = state_groups[0].group
event.current_state = state_groups[0].state
return
state = {}
state_sets = {}
for group in state_groups:
for s in group.state:
state.setdefault((s.type, s.state_key), []).add(s)
state_sets.setdefault(
(s.type, s.state_key),
set()
).add(s.event_id)
unconflicted_state = {
k: v.pop() for k, v in state_sets.items()
if len(v) == 1
}
conflicted_state = {
k: state[k]
for k, v in state_sets.items()
if len(v) > 1
}
new_state = {}
new_state.update(unconflicted_state)
for key, events in conflicted_state.items():
new_state[key] = yield self.resolve(events)
if hasattr(event, "state_key"):
new_state[(event.type, event.state_key)] = event
event.state_group = None
event.current_state = new_state.values()
@defer.inlineCallbacks
def resolve(self, events):
curr_events = events
new_powers_deferreds = []
for e in curr_events:
new_powers_deferreds.append(
self.store.get_power_level(e.context, e.user_id)
)
new_powers = yield defer.gatherResults(
new_powers_deferreds,
consumeErrors=True
)
max_power = max([int(p) for p in new_powers])
curr_events = [
z[0] for z in zip(curr_events, new_powers)
if int(z[1]) == max_power
]
if not curr_events:
raise RuntimeError("Max didn't get a max?")
elif len(curr_events) == 1:
defer.returnValue(curr_events[0])
# TODO: For now, just choose the one with the largest event_id.
defer.returnValue(
sorted(
curr_events,
key=lambda e: hashlib.sha1(
e.event_id + e.user_id + e.room_id + e.type
).hexdigest()
)[0]
)
def _get_power_level_for_event(self, event): def _get_power_level_for_event(self, event):
# return self._persistence.get_power_level_for_user(event.room_id, # return self._persistence.get_power_level_for_user(event.room_id,
# event.sender) # event.sender)

View file

@ -40,6 +40,7 @@ from .stream import StreamStore
from .pdu import StatePduStore, PduStore, PdusTable from .pdu import StatePduStore, PduStore, PdusTable
from .transactions import TransactionStore from .transactions import TransactionStore
from .keys import KeyStore from .keys import KeyStore
from .state import StateStore
import json import json
import logging import logging
@ -59,6 +60,7 @@ SCHEMAS = [
"room_aliases", "room_aliases",
"keys", "keys",
"redactions", "redactions",
"state",
] ]
@ -76,7 +78,7 @@ class _RollbackButIsFineException(Exception):
class DataStore(RoomMemberStore, RoomStore, class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, FeedbackStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
PresenceStore, PduStore, StatePduStore, TransactionStore, PresenceStore, PduStore, StatePduStore, TransactionStore,
DirectoryStore, KeyStore): DirectoryStore, KeyStore, StateStore):
def __init__(self, hs): def __init__(self, hs):
super(DataStore, self).__init__(hs) super(DataStore, self).__init__(hs)
@ -222,6 +224,8 @@ class DataStore(RoomMemberStore, RoomStore,
) )
raise _RollbackButIsFineException("_persist_event") raise _RollbackButIsFineException("_persist_event")
self._store_state_groups_txn(txn, event)
is_state = hasattr(event, "state_key") and event.state_key is not None is_state = hasattr(event, "state_key") and event.state_key is not None
if is_new_state and is_state: if is_new_state and is_state:
vals = { vals = {

View file

@ -0,0 +1,33 @@
/* Copyright 2014 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS state_groups(
id INTEGER PRIMARY KEY AUTOINCREMENT,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS state_groups_state(
state_group INTEGER NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
event_id TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS event_to_state_groups(
event_id TEXT NOT NULL,
state_group INTEGER NOT NULL
);