Merge remote-tracking branch 'origin/develop' into markjh/3pid

This commit is contained in:
Mark Haines 2016-01-29 14:15:12 +00:00
commit 47374a33fc
17 changed files with 583 additions and 349 deletions

View file

@ -15,6 +15,8 @@
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID from synapse.types import UserID, RoomID
import ujson as json
class Filtering(object): class Filtering(object):
@ -149,6 +151,9 @@ class FilterCollection(object):
"include_leave", False "include_leave", False
) )
def __repr__(self):
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
def get_filter_json(self): def get_filter_json(self):
return self._filter_json return self._filter_json

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError, AuthError, Codes from synapse.api.errors import AuthError, Codes
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
@ -119,9 +119,12 @@ class MessageHandler(BaseHandler):
if source_config.direction == 'b': if source_config.direction == 'b':
# if we're going backwards, we might need to backfill. This # if we're going backwards, we might need to backfill. This
# requires that we have a topo token. # requires that we have a topo token.
if room_token.topological is None: if room_token.topological:
raise SynapseError(400, "Invalid token: cannot paginate " max_topo = room_token.topological
"backwards from a stream token") else:
max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
room_id, room_token.stream
)
if membership == Membership.LEAVE: if membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before # If they have left the room then clamp the token to be before
@ -131,11 +134,11 @@ class MessageHandler(BaseHandler):
member_event_id member_event_id
) )
leave_token = RoomStreamToken.parse(leave_token) leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < room_token.topological: if leave_token.topological < max_topo:
source_config.from_key = str(leave_token) source_config.from_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill( yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, room_token.topological room_id, max_topo
) )
events, next_key = yield data_source.get_pagination_rows( events, next_key = yield data_source.get_pagination_rows(

View file

@ -72,7 +72,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
) )
class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [ class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
"room_id", # str "room_id", # str
"timeline", # TimelineBatch "timeline", # TimelineBatch
"state", # dict[(str, str), FrozenEvent] "state", # dict[(str, str), FrozenEvent]
@ -298,46 +298,19 @@ class SyncHandler(BaseHandler):
room_id, sync_config, now_token, since_token=timeline_since_token room_id, sync_config, now_token, since_token=timeline_since_token
) )
notifs = yield self.unread_notifs_for_room_id( room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, ephemeral_by_room room_id, sync_config,
now_token=now_token,
since_token=timeline_since_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
all_ephemeral_by_room=ephemeral_by_room,
batch=batch,
full_state=True,
) )
unread_notifications = {} defer.returnValue(room_sync)
if notifs is not None:
unread_notifications["notification_count"] = len(notifs)
unread_notifications["highlight_count"] = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
current_state = yield self.get_state_at(room_id, now_token)
current_state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
current_state.values()
)
}
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
)
ephemeral = sync_config.filter_collection.filter_room_ephemeral(
ephemeral_by_room.get(room_id, [])
)
defer.returnValue(JoinedSyncResult(
room_id=room_id,
timeline=batch,
state=current_state,
ephemeral=ephemeral,
account_data=account_data,
unread_notifications=unread_notifications,
))
def account_data_for_user(self, account_data): def account_data_for_user(self, account_data):
account_data_events = [] account_data_events = []
@ -429,44 +402,20 @@ class SyncHandler(BaseHandler):
defer.returnValue((now_token, ephemeral_by_room)) defer.returnValue((now_token, ephemeral_by_room))
@defer.inlineCallbacks
def full_state_sync_for_archived_room(self, room_id, sync_config, def full_state_sync_for_archived_room(self, room_id, sync_config,
leave_event_id, leave_token, leave_event_id, leave_token,
timeline_since_token, tags_by_room, timeline_since_token, tags_by_room,
account_data_by_room): account_data_by_room):
"""Sync a room for a client which is starting without any state """Sync a room for a client which is starting without any state
Returns: Returns:
A Deferred JoinedSyncResult. A Deferred ArchivedSyncResult.
""" """
batch = yield self.load_filtered_recents( return self.incremental_sync_for_archived_room(
room_id, sync_config, leave_token, since_token=timeline_since_token sync_config, room_id, leave_event_id, timeline_since_token, tags_by_room,
account_data_by_room, full_state=True, leave_token=leave_token,
) )
leave_state = yield self.store.get_state_for_event(leave_event_id)
leave_state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
leave_state.values()
)
}
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
account_data = sync_config.filter_collection.filter_room_account_data(
account_data
)
defer.returnValue(ArchivedSyncResult(
room_id=room_id,
timeline=batch,
state=leave_state,
account_data=account_data,
))
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_with_gap(self, sync_config, since_token): def incremental_sync_with_gap(self, sync_config, since_token):
""" Get the incremental delta needed to bring the client up to """ Get the incremental delta needed to bring the client up to
@ -512,154 +461,127 @@ class SyncHandler(BaseHandler):
sync_config.user sync_config.user
) )
user_id = sync_config.user.to_string()
timeline_limit = sync_config.filter_collection.timeline_limit() timeline_limit = sync_config.filter_collection.timeline_limit()
room_events, _ = yield self.store.get_room_events_stream(
sync_config.user.to_string(),
from_key=since_token.room_key,
to_key=now_token.room_key,
limit=timeline_limit + 1,
)
tags_by_room = yield self.store.get_updated_tags( tags_by_room = yield self.store.get_updated_tags(
sync_config.user.to_string(), user_id,
since_token.account_data_key, since_token.account_data_key,
) )
account_data, account_data_by_room = ( account_data, account_data_by_room = (
yield self.store.get_updated_account_data_for_user( yield self.store.get_updated_account_data_for_user(
sync_config.user.to_string(), user_id,
since_token.account_data_key, since_token.account_data_key,
) )
) )
joined = [] # Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_room_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
mem_change_events_by_room_id = {}
for event in rooms_changed:
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
newly_joined_rooms = []
archived = [] archived = []
if len(room_events) <= timeline_limit: invited = []
# There is no gap in any of the rooms. Therefore we can just for room_id, events in mem_change_events_by_room_id.items():
# partition the new events by room and return them. non_joins = [e for e in events if e.membership != Membership.JOIN]
logger.debug("Got %i events for incremental sync - not limited", has_join = len(non_joins) != len(events)
len(room_events))
invite_events = [] # We want to figure out if we joined the room at some point since
leave_events = [] # the last sync (even if we have since left). This is to make sure
events_by_room_id = {} # we do send down the room, and with full state, where necessary
for event in room_events: if room_id in joined_room_ids or has_join:
events_by_room_id.setdefault(event.room_id, []).append(event) old_state = yield self.get_state_at(room_id, since_token)
if event.room_id not in joined_room_ids: old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
if (event.type == EventTypes.Member if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
and event.state_key == sync_config.user.to_string()): newly_joined_rooms.append(room_id)
if event.membership == Membership.INVITE:
invite_events.append(event)
elif event.membership in (Membership.LEAVE, Membership.BAN):
leave_events.append(event)
for room_id in joined_room_ids: if room_id in joined_room_ids:
recents = events_by_room_id.get(room_id, []) continue
logger.debug("Events for room %s: %r", room_id, recents)
state = {
(event.type, event.state_key): event
for event in recents if event.is_state()}
limited = False
if recents: if not non_joins:
prev_batch = now_token.copy_and_replace( continue
"room_key", recents[0].internal_metadata.before
)
else:
prev_batch = now_token
just_joined = yield self.check_joined_room(sync_config, state)
if just_joined:
logger.debug("User has just joined %s: needs full state",
room_id)
state = yield self.get_state_at(room_id, now_token)
# the timeline is inherently limited if we've just joined
limited = True
recents = sync_config.filter_collection.filter_room_timeline(recents)
state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
state.values()
)
}
acc_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
acc_data = sync_config.filter_collection.filter_room_account_data(
acc_data
)
ephemeral = sync_config.filter_collection.filter_room_ephemeral(
ephemeral_by_room.get(room_id, [])
)
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=TimelineBatch(
events=recents,
prev_batch=prev_batch,
limited=limited,
),
state=state,
ephemeral=ephemeral,
account_data=acc_data,
unread_notifications={},
)
logger.debug("Result for room %s: %r", room_id, room_sync)
# Only bother if we're still currently invited
should_invite = non_joins[-1].membership == Membership.INVITE
if should_invite:
room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
if room_sync: if room_sync:
notifs = yield self.unread_notifs_for_room_id( invited.append(room_sync)
room_id, sync_config, all_ephemeral_by_room
)
if notifs is not None: # Always include leave/ban events. Just take the last one.
notif_dict = room_sync.unread_notifications # TODO: How do we handle ban -> leave in same batch?
notif_dict["notification_count"] = len(notifs) leave_events = [
notif_dict["highlight_count"] = len([ e for e in non_joins
1 for notif in notifs if e.membership in (Membership.LEAVE, Membership.BAN)
if _action_has_highlight(notif["actions"]) ]
])
joined.append(room_sync) if leave_events:
leave_event = leave_events[-1]
else:
logger.debug("Got %i events for incremental sync - hit limit",
len(room_events))
invite_events = yield self.store.get_invites_for_user(
sync_config.user.to_string()
)
leave_events = yield self.store.get_leave_and_ban_events_for_user(
sync_config.user.to_string()
)
for room_id in joined_room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token,
ephemeral_by_room, tags_by_room, account_data_by_room,
all_ephemeral_by_room=all_ephemeral_by_room,
)
if room_sync:
joined.append(room_sync)
for leave_event in leave_events:
room_sync = yield self.incremental_sync_for_archived_room( room_sync = yield self.incremental_sync_for_archived_room(
sync_config, leave_event, since_token, tags_by_room, sync_config, room_id, leave_event.event_id, since_token,
account_data_by_room tags_by_room, account_data_by_room,
full_state=room_id in newly_joined_rooms
) )
if room_sync: if room_sync:
archived.append(room_sync) archived.append(room_sync)
invited = [ # Get all events for rooms we're currently joined to.
InvitedSyncResult(room_id=event.room_id, invite=event) room_to_events = yield self.store.get_room_events_stream_for_rooms(
for event in invite_events room_ids=joined_room_ids,
] from_key=since_token.room_key,
to_key=now_token.room_key,
limit=timeline_limit + 1,
)
joined = []
# We loop through all room ids, even if there are no new events, in case
# there are non room events taht we need to notify about.
for room_id in joined_room_ids:
room_entry = room_to_events.get(room_id, None)
if room_entry:
events, start_key = room_entry
prev_batch_token = now_token.copy_and_replace("room_key", start_key)
newly_joined_room = room_id in newly_joined_rooms
full_state = newly_joined_room
batch = yield self.load_filtered_recents(
room_id, sync_config, prev_batch_token,
since_token=since_token,
recents=events,
newly_joined_room=newly_joined_room,
)
else:
batch = TimelineBatch(
events=[],
prev_batch=since_token,
limited=False,
)
full_state = False
room_sync = yield self.incremental_sync_with_gap_for_room(
room_id=room_id,
sync_config=sync_config,
since_token=since_token,
now_token=now_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
all_ephemeral_by_room=all_ephemeral_by_room,
batch=batch,
full_state=full_state,
)
if room_sync:
joined.append(room_sync)
account_data_for_user = sync_config.filter_collection.filter_account_data( account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data) self.account_data_for_user(account_data)
@ -680,28 +602,40 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token, def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None): since_token=None, recents=None, newly_joined_room=False):
""" """
:returns a Deferred TimelineBatch :returns a Deferred TimelineBatch
""" """
limited = True
recents = []
filtering_factor = 2 filtering_factor = 2
timeline_limit = sync_config.filter_collection.timeline_limit() timeline_limit = sync_config.filter_collection.timeline_limit()
load_limit = max(timeline_limit * filtering_factor, 100) load_limit = max(timeline_limit * filtering_factor, 10)
max_repeat = 3 # Only try a few times per room, otherwise max_repeat = 5 # Only try a few times per room, otherwise
room_key = now_token.room_key room_key = now_token.room_key
end_key = room_key end_key = room_key
limited = recents is None or newly_joined_room or timeline_limit < len(recents)
if recents is not None:
recents = sync_config.filter_collection.filter_room_timeline(recents)
recents = yield self._filter_events_for_client(
sync_config.user.to_string(),
recents,
is_peeking=sync_config.is_guest,
)
else:
recents = []
since_key = None
if since_token and not newly_joined_room:
since_key = since_token.room_key
while limited and len(recents) < timeline_limit and max_repeat: while limited and len(recents) < timeline_limit and max_repeat:
events, keys = yield self.store.get_recent_events_for_room( events, end_key = yield self.store.get_room_events_stream_for_room(
room_id, room_id,
limit=load_limit + 1, limit=load_limit + 1,
from_token=since_token.room_key if since_token else None, from_key=since_key,
end_token=end_key, to_key=end_key,
) )
room_key, _ = keys
end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter_collection.filter_room_timeline(events) loaded_recents = sync_config.filter_collection.filter_room_timeline(events)
loaded_recents = yield self._filter_events_for_client( loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(), sync_config.user.to_string(),
@ -710,8 +644,10 @@ class SyncHandler(BaseHandler):
) )
loaded_recents.extend(recents) loaded_recents.extend(recents)
recents = loaded_recents recents = loaded_recents
if len(events) <= load_limit: if len(events) <= load_limit:
limited = False limited = False
break
max_repeat -= 1 max_repeat -= 1
if len(recents) > timeline_limit: if len(recents) > timeline_limit:
@ -724,7 +660,9 @@ class SyncHandler(BaseHandler):
) )
defer.returnValue(TimelineBatch( defer.returnValue(TimelineBatch(
events=recents, prev_batch=prev_batch_token, limited=limited events=recents,
prev_batch=prev_batch_token,
limited=limited or newly_joined_room
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -732,25 +670,12 @@ class SyncHandler(BaseHandler):
since_token, now_token, since_token, now_token,
ephemeral_by_room, tags_by_room, ephemeral_by_room, tags_by_room,
account_data_by_room, account_data_by_room,
all_ephemeral_by_room): all_ephemeral_by_room,
""" Get the incremental delta needed to bring the client up to date for batch, full_state=False):
the room. Gives the client the most recent events and the changes to if full_state:
state. state = yield self.get_state_at(room_id, now_token)
Returns:
A Deferred JoinedSyncResult
"""
logger.debug("Doing incremental sync for room %s between %s and %s",
room_id, since_token, now_token)
# TODO(mjark): Check for redactions we might have missed. elif batch.limited:
batch = yield self.load_filtered_recents(
room_id, sync_config, now_token, since_token,
)
logger.debug("Recents %r", batch)
if batch.limited:
current_state = yield self.get_state_at(room_id, now_token) current_state = yield self.get_state_at(room_id, now_token)
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = yield self.get_state_at(
@ -772,17 +697,6 @@ class SyncHandler(BaseHandler):
if just_joined: if just_joined:
state = yield self.get_state_at(room_id, now_token) state = yield self.get_state_at(room_id, now_token)
notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, all_ephemeral_by_room
)
unread_notifications = {}
if notifs is not None:
unread_notifications["notification_count"] = len(notifs)
unread_notifications["highlight_count"] = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
state = { state = {
(e.type, e.state_key): e (e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values()) for e in sync_config.filter_collection.filter_room_state(state.values())
@ -800,6 +714,7 @@ class SyncHandler(BaseHandler):
ephemeral_by_room.get(room_id, []) ephemeral_by_room.get(room_id, [])
) )
unread_notifications = {}
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
@ -809,41 +724,55 @@ class SyncHandler(BaseHandler):
unread_notifications=unread_notifications, unread_notifications=unread_notifications,
) )
if room_sync:
notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, all_ephemeral_by_room
)
if notifs is not None:
unread_notifications["notification_count"] = len(notifs)
unread_notifications["highlight_count"] = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
logger.debug("Room sync: %r", room_sync) logger.debug("Room sync: %r", room_sync)
defer.returnValue(room_sync) defer.returnValue(room_sync)
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_for_archived_room(self, sync_config, leave_event, def incremental_sync_for_archived_room(self, sync_config, room_id, leave_event_id,
since_token, tags_by_room, since_token, tags_by_room,
account_data_by_room): account_data_by_room, full_state,
leave_token=None):
""" Get the incremental delta needed to bring the client up to date for """ Get the incremental delta needed to bring the client up to date for
the archived room. the archived room.
Returns: Returns:
A Deferred ArchivedSyncResult A Deferred ArchivedSyncResult
""" """
if not leave_token:
stream_token = yield self.store.get_stream_token_for_event( stream_token = yield self.store.get_stream_token_for_event(
leave_event.event_id leave_event_id
) )
leave_token = since_token.copy_and_replace("room_key", stream_token) leave_token = since_token.copy_and_replace("room_key", stream_token)
if since_token.is_after(leave_token): if since_token and since_token.is_after(leave_token):
defer.returnValue(None) defer.returnValue(None)
batch = yield self.load_filtered_recents( batch = yield self.load_filtered_recents(
leave_event.room_id, sync_config, leave_token, since_token, room_id, sync_config, leave_token, since_token,
) )
logger.debug("Recents %r", batch) logger.debug("Recents %r", batch)
state_events_at_leave = yield self.store.get_state_for_event( state_events_at_leave = yield self.store.get_state_for_event(
leave_event.event_id leave_event_id
) )
if not full_state:
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = yield self.get_state_at(
leave_event.room_id, stream_position=since_token room_id, stream_position=since_token
) )
state_events_delta = yield self.compute_state_delta( state_events_delta = yield self.compute_state_delta(
@ -851,6 +780,8 @@ class SyncHandler(BaseHandler):
previous_state=state_at_previous_sync, previous_state=state_at_previous_sync,
current_state=state_events_at_leave, current_state=state_events_at_leave,
) )
else:
state_events_delta = state_events_at_leave
state_events_delta = { state_events_delta = {
(e.type, e.state_key): e (e.type, e.state_key): e
@ -860,7 +791,7 @@ class SyncHandler(BaseHandler):
} }
account_data = self.account_data_for_room( account_data = self.account_data_for_room(
leave_event.room_id, tags_by_room, account_data_by_room room_id, tags_by_room, account_data_by_room
) )
account_data = sync_config.filter_collection.filter_room_account_data( account_data = sync_config.filter_collection.filter_room_account_data(
@ -868,7 +799,7 @@ class SyncHandler(BaseHandler):
) )
room_sync = ArchivedSyncResult( room_sync = ArchivedSyncResult(
room_id=leave_event.room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=state_events_delta, state=state_events_delta,
account_data=account_data, account_data=account_data,

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer from twisted.internet import defer
import ujson as json import ujson as json
@ -23,6 +24,14 @@ logger = logging.getLogger(__name__)
class AccountDataStore(SQLBaseStore): class AccountDataStore(SQLBaseStore):
def __init__(self, hs):
super(AccountDataStore, self).__init__(hs)
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache",
self._account_data_id_gen.get_max_token(None),
max_size=10000,
)
def get_account_data_for_user(self, user_id): def get_account_data_for_user(self, user_id):
"""Get all the client account_data for a user. """Get all the client account_data for a user.
@ -83,7 +92,7 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_room", get_account_data_for_room_txn "get_account_data_for_room", get_account_data_for_room_txn
) )
def get_updated_account_data_for_user(self, user_id, stream_id): def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None):
"""Get all the client account_data for a that's changed. """Get all the client account_data for a that's changed.
Args: Args:
@ -120,6 +129,12 @@ class AccountDataStore(SQLBaseStore):
return (global_account_data, account_data_by_room) return (global_account_data, account_data_by_room)
changed = self._account_data_stream_cache.has_entity_changed(
user_id, int(stream_id)
)
if not changed:
return ({}, {})
return self.runInteraction( return self.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
) )
@ -186,6 +201,10 @@ class AccountDataStore(SQLBaseStore):
"content": content_json, "content": content_json,
} }
) )
txn.call_after(
self._account_data_stream_cache.entity_has_changed,
user_id, next_id,
)
self._update_max_stream_id(txn, next_id) self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id: with (yield self._account_data_id_gen.get_next(self)) as next_id:

View file

@ -210,6 +210,12 @@ class EventsStore(SQLBaseStore):
for event, _ in events_and_contexts: for event, _ in events_and_contexts:
txn.call_after(self._invalidate_get_event_cache, event.event_id) txn.call_after(self._invalidate_get_event_cache, event.event_id)
if not backfilled:
txn.call_after(
self._events_stream_cache.entity_has_changed,
event.room_id, event.internal_metadata.stream_ordering,
)
depth_updates = {} depth_updates = {}
for event, _ in events_and_contexts: for event, _ in events_and_contexts:
if event.internal_metadata.is_outlier(): if event.internal_metadata.is_outlier():

View file

@ -16,12 +16,13 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
import simplejson as json import simplejson as json
class FilteringStore(SQLBaseStore): class FilteringStore(SQLBaseStore):
@defer.inlineCallbacks @cachedInlineCallbacks(num_args=2)
def get_user_filter(self, user_localpart, filter_id): def get_user_filter(self, user_localpart, filter_id):
def_json = yield self._simple_select_one_onecol( def_json = yield self._simple_select_one_onecol(
table="user_filters", table="user_filters",

View file

@ -15,11 +15,10 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
from synapse.util.caches import cache_counter, caches_by_name from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer from twisted.internet import defer
from blist import sorteddict
import logging import logging
import ujson as json import ujson as json
@ -31,8 +30,8 @@ class ReceiptsStore(SQLBaseStore):
def __init__(self, hs): def __init__(self, hs):
super(ReceiptsStore, self).__init__(hs) super(ReceiptsStore, self).__init__(hs)
self._receipts_stream_cache = _RoomStreamChangeCache( self._receipts_stream_cache = StreamChangeCache(
self._receipts_id_gen.get_max_token(None) "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None)
) )
@cached(num_args=2) @cached(num_args=2)
@ -78,8 +77,8 @@ class ReceiptsStore(SQLBaseStore):
room_ids = set(room_ids) room_ids = set(room_ids)
if from_key: if from_key:
room_ids = yield self._receipts_stream_cache.get_rooms_changed( room_ids = yield self._receipts_stream_cache.get_entities_changed(
self, room_ids, from_key room_ids, from_key
) )
results = yield self._get_linearized_receipts_for_rooms( results = yield self._get_linearized_receipts_for_rooms(
@ -222,6 +221,11 @@ class ReceiptsStore(SQLBaseStore):
# FIXME: This shouldn't invalidate the whole cache # FIXME: This shouldn't invalidate the whole cache
txn.call_after(self.get_linearized_receipts_for_room.invalidate_all) txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
txn.call_after(
self._receipts_stream_cache.entity_has_changed,
room_id, stream_id
)
# We don't want to clobber receipts for more recent events, so we # We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts # have to compare orderings of existing receipts
sql = ( sql = (
@ -309,9 +313,6 @@ class ReceiptsStore(SQLBaseStore):
stream_id_manager = yield self._receipts_id_gen.get_next(self) stream_id_manager = yield self._receipts_id_gen.get_next(self)
with stream_id_manager as stream_id: with stream_id_manager as stream_id:
yield self._receipts_stream_cache.room_has_changed(
self, room_id, stream_id
)
have_persisted = yield self.runInteraction( have_persisted = yield self.runInteraction(
"insert_linearized_receipt", "insert_linearized_receipt",
self.insert_linearized_receipt_txn, self.insert_linearized_receipt_txn,
@ -370,63 +371,3 @@ class ReceiptsStore(SQLBaseStore):
"data": json.dumps(data), "data": json.dumps(data),
} }
) )
class _RoomStreamChangeCache(object):
"""Keeps track of the stream_id of the latest change in rooms.
Given a list of rooms and stream key, it will give a subset of rooms that
may have changed since that key. If the key is too old then the cache
will simply return all rooms.
"""
def __init__(self, current_key, size_of_cache=10000):
self._size_of_cache = size_of_cache
self._room_to_key = {}
self._cache = sorteddict()
self._earliest_key = current_key
self.name = "ReceiptsRoomChangeCache"
caches_by_name[self.name] = self._cache
@defer.inlineCallbacks
def get_rooms_changed(self, store, room_ids, key):
"""Returns subset of room ids that have had new receipts since the
given key. If the key is too old it will just return the given list.
"""
if key > (yield self._get_earliest_key(store)):
keys = self._cache.keys()
i = keys.bisect_right(key)
result = set(
self._cache[k] for k in keys[i:]
).intersection(room_ids)
cache_counter.inc_hits(self.name)
else:
result = room_ids
cache_counter.inc_misses(self.name)
defer.returnValue(result)
@defer.inlineCallbacks
def room_has_changed(self, store, room_id, key):
"""Informs the cache that the room has been changed at the given key.
"""
if key > (yield self._get_earliest_key(store)):
old_key = self._room_to_key.get(room_id, None)
if old_key:
key = max(key, old_key)
self._cache.pop(old_key, None)
self._cache[key] = room_id
while len(self._cache) > self._size_of_cache:
k, r = self._cache.popitem()
self._earliest_key = max(k, self._earliest_key)
self._room_to_key.pop(r, None)
@defer.inlineCallbacks
def _get_earliest_key(self, store):
if self._earliest_key is None:
self._earliest_key = yield store.get_max_receipt_stream_id()
self._earliest_key = int(self._earliest_key)
defer.returnValue(self._earliest_key)

View file

@ -0,0 +1,16 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX events_room_stream on events(room_id, stream_ordering);

View file

@ -37,6 +37,7 @@ from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -77,6 +78,12 @@ def upper_bound(token):
class StreamStore(SQLBaseStore): class StreamStore(SQLBaseStore):
def __init__(self, hs):
super(StreamStore, self).__init__(hs)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", self._stream_id_gen.get_max_token(None)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_appservice_room_stream(self, service, from_key, to_key, limit=0): def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
@ -157,6 +164,135 @@ class StreamStore(SQLBaseStore):
results = yield self.runInteraction("get_appservice_room_stream", f) results = yield self.runInteraction("get_appservice_room_stream", f)
defer.returnValue(results) defer.returnValue(results)
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0):
from_id = RoomStreamToken.parse_stream_token(from_key).stream
room_ids = yield self._events_stream_cache.get_entities_changed(
room_ids, from_id
)
if not room_ids:
defer.returnValue({})
results = {}
room_ids = list(room_ids)
for rm_ids in (room_ids[i:i+20] for i in xrange(0, len(room_ids), 20)):
res = yield defer.gatherResults([
self.get_room_events_stream_for_room(
room_id, from_key, to_key, limit
).addCallback(lambda r, rm: (rm, r), room_id)
for room_id in room_ids
])
results.update(dict(res))
defer.returnValue(results)
@defer.inlineCallbacks
def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0):
if from_key is not None:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
else:
from_id = None
to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key:
defer.returnValue(([], from_key))
if from_id:
has_changed = yield self._events_stream_cache.has_entity_changed(
room_id, from_id
)
if not has_changed:
defer.returnValue(([], from_key))
def f(txn):
if from_id is not None:
sql = (
"SELECT event_id, stream_ordering FROM events WHERE"
" room_id = ?"
" AND not outlier"
" AND stream_ordering > ? AND stream_ordering <= ?"
" ORDER BY stream_ordering DESC LIMIT ?"
)
txn.execute(sql, (room_id, from_id, to_id, limit))
else:
sql = (
"SELECT event_id, stream_ordering FROM events WHERE"
" room_id = ?"
" AND not outlier"
" AND stream_ordering <= ?"
" ORDER BY stream_ordering DESC LIMIT ?"
)
txn.execute(sql, (room_id, to_id, limit))
rows = self.cursor_to_dict(txn)
ret = self._get_events_txn(
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
self._set_before_and_after(ret, rows, topo_order=False)
ret.reverse()
if rows:
key = "s%d" % min(r["stream_ordering"] for r in rows)
else:
# Assume we didn't get anything because there was nothing to
# get.
key = from_key
return ret, key
res = yield self.runInteraction("get_room_events_stream_for_room", f)
defer.returnValue(res)
def get_room_changes_for_user(self, user_id, from_key, to_key):
if from_key is not None:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
else:
from_id = None
to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key:
return defer.succeed([])
def f(txn):
if from_id is not None:
sql = (
"SELECT m.event_id, stream_ordering FROM events AS e,"
" room_memberships AS m"
" WHERE e.event_id = m.event_id"
" AND m.user_id = ?"
" AND e.stream_ordering > ? AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
)
txn.execute(sql, (user_id, from_id, to_id,))
else:
sql = (
"SELECT m.event_id, stream_ordering FROM events AS e,"
" room_memberships AS m"
" WHERE e.event_id = m.event_id"
" AND m.user_id = ?"
" AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
)
txn.execute(sql, (user_id, to_id,))
rows = self.cursor_to_dict(txn)
ret = self._get_events_txn(
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
return ret
return self.runInteraction("get_room_changes_for_user", f)
@log_function @log_function
def get_room_events_stream( def get_room_events_stream(
self, self,
@ -174,7 +310,8 @@ class StreamStore(SQLBaseStore):
"SELECT c.room_id FROM history_visibility AS h" "SELECT c.room_id FROM history_visibility AS h"
" INNER JOIN current_state_events AS c" " INNER JOIN current_state_events AS c"
" ON h.event_id = c.event_id" " ON h.event_id = c.event_id"
" WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % ( " WHERE c.room_id IN (%s)"
" AND h.history_visibility = 'world_readable'" % (
",".join(map(lambda _: "?", room_ids)) ",".join(map(lambda _: "?", room_ids))
) )
) )
@ -434,6 +571,18 @@ class StreamStore(SQLBaseStore):
row["topological_ordering"], row["stream_ordering"],) row["topological_ordering"], row["stream_ordering"],)
) )
def get_max_topological_token_for_stream_and_room(self, room_id, stream_key):
sql = (
"SELECT max(topological_ordering) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
return self._execute(
"get_max_topological_token_for_stream_and_room", None,
sql, room_id, stream_key,
).addCallback(
lambda r: r[0][0] if r else 0
)
def _get_max_topological_txn(self, txn): def _get_max_topological_txn(self, txn):
txn.execute( txn.execute(
"SELECT MAX(topological_ordering) FROM events" "SELECT MAX(topological_ordering) FROM events"
@ -445,10 +594,13 @@ class StreamStore(SQLBaseStore):
return rows[0][0] if rows else 0 return rows[0][0] if rows else 0
@staticmethod @staticmethod
def _set_before_and_after(events, rows): def _set_before_and_after(events, rows, topo_order=True):
for event, row in zip(events, rows): for event, row in zip(events, rows):
stream = row["stream_ordering"] stream = row["stream_ordering"]
if topo_order:
topo = event.depth topo = event.depth
else:
topo = None
internal = event.internal_metadata internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1)) internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream)) internal.after = str(RoomStreamToken(topo, stream))

View file

@ -24,7 +24,6 @@ logger = logging.getLogger(__name__)
class TagsStore(SQLBaseStore): class TagsStore(SQLBaseStore):
def get_max_account_data_stream_id(self): def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream """Get the current max stream id for the private user data stream
@ -80,6 +79,12 @@ class TagsStore(SQLBaseStore):
room_ids = [row[0] for row in txn.fetchall()] room_ids = [row[0] for row in txn.fetchall()]
return room_ids return room_ids
changed = self._account_data_stream_cache.has_entity_changed(
user_id, int(stream_id)
)
if not changed:
defer.returnValue({})
room_ids = yield self.runInteraction( room_ids = yield self.runInteraction(
"get_updated_tags", get_updated_tags_txn "get_updated_tags", get_updated_tags_txn
) )
@ -177,6 +182,11 @@ class TagsStore(SQLBaseStore):
next_id(int): The the revision to advance to. next_id(int): The the revision to advance to.
""" """
txn.call_after(
self._account_data_stream_cache.entity_has_changed,
user_id, next_id
)
update_max_id_sql = ( update_max_id_sql = (
"UPDATE account_data_max_stream_id" "UPDATE account_data_max_stream_id"
" SET stream_id = ?" " SET stream_id = ?"

View file

@ -37,7 +37,7 @@ class LruCache(object):
""" """
def __init__(self, max_size, keylen=1, cache_type=dict): def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type() cache = cache_type()
self.size = 0 self.cache = cache # Used for introspection.
list_root = [] list_root = []
list_root[:] = [list_root, list_root, None, None] list_root[:] = [list_root, list_root, None, None]
@ -60,7 +60,6 @@ class LruCache(object):
prev_node[NEXT] = node prev_node[NEXT] = node
next_node[PREV] = node next_node[PREV] = node
cache[key] = node cache[key] = node
self.size += 1
def move_node_to_front(node): def move_node_to_front(node):
prev_node = node[PREV] prev_node = node[PREV]
@ -79,7 +78,6 @@ class LruCache(object):
next_node = node[NEXT] next_node = node[NEXT]
prev_node[NEXT] = next_node prev_node[NEXT] = next_node
next_node[PREV] = prev_node next_node[PREV] = prev_node
self.size -= 1
@synchronized @synchronized
def cache_get(key, default=None): def cache_get(key, default=None):
@ -98,7 +96,7 @@ class LruCache(object):
node[VALUE] = value node[VALUE] = value
else: else:
add_node(key, value) add_node(key, value)
if self.size > max_size: if len(cache) > max_size:
todelete = list_root[PREV] todelete = list_root[PREV]
delete_node(todelete) delete_node(todelete)
cache.pop(todelete[KEY], None) cache.pop(todelete[KEY], None)
@ -110,7 +108,7 @@ class LruCache(object):
return node[VALUE] return node[VALUE]
else: else:
add_node(key, value) add_node(key, value)
if self.size > max_size: if len(cache) > max_size:
todelete = list_root[PREV] todelete = list_root[PREV]
delete_node(todelete) delete_node(todelete)
cache.pop(todelete[KEY], None) cache.pop(todelete[KEY], None)
@ -145,7 +143,7 @@ class LruCache(object):
@synchronized @synchronized
def cache_len(): def cache_len():
return self.size return len(cache)
@synchronized @synchronized
def cache_contains(key): def cache_contains(key):

View file

@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.caches import cache_counter, caches_by_name
from blist import sorteddict
import logging
logger = logging.getLogger(__name__)
class StreamChangeCache(object):
"""Keeps track of the stream positions of the latest change in a set of entities.
Typically the entity will be a room or user id.
Given a list of entities and a stream position, it will give a subset of
entities that may have changed since that position. If position key is too
old then the cache will simply return all given entities.
"""
def __init__(self, name, current_stream_pos, max_size=10000):
self._max_size = max_size
self._entity_to_key = {}
self._cache = sorteddict()
self._earliest_known_stream_pos = current_stream_pos
self.name = name
caches_by_name[self.name] = self._cache
def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos
"""
assert type(stream_pos) is int
if stream_pos < self._earliest_known_stream_pos:
cache_counter.inc_misses(self.name)
return True
if stream_pos == self._earliest_known_stream_pos:
# If the same as the earliest key, assume nothing has changed.
cache_counter.inc_hits(self.name)
return False
latest_entity_change_pos = self._entity_to_key.get(entity, None)
if latest_entity_change_pos is None:
cache_counter.inc_misses(self.name)
return True
if stream_pos < latest_entity_change_pos:
cache_counter.inc_misses(self.name)
return True
cache_counter.inc_hits(self.name)
return False
def get_entities_changed(self, entities, stream_pos):
"""Returns subset of entities that have had new things since the
given position. If the position is too old it will just return the given list.
"""
assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos:
keys = self._cache.keys()
i = keys.bisect_right(stream_pos)
result = set(
self._cache[k] for k in keys[i:]
).intersection(entities)
cache_counter.inc_hits(self.name)
else:
result = entities
cache_counter.inc_misses(self.name)
return result
def entity_has_changed(self, entity, stream_pos):
"""Informs the cache that the entity has been changed at the given
position.
"""
assert type(stream_pos) is int
if stream_pos > self._earliest_known_stream_pos:
old_pos = self._entity_to_key.get(entity, None)
if old_pos:
stream_pos = max(stream_pos, old_pos)
self._cache.pop(old_pos, None)
self._cache[stream_pos] = entity
self._entity_to_key[entity] = stream_pos
while len(self._cache) > self._max_size:
k, r = self._cache.popitem()
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
self._entity_to_key.pop(r, None)

View file

@ -8,6 +8,7 @@ class TreeCache(object):
Keys must be tuples. Keys must be tuples.
""" """
def __init__(self): def __init__(self):
self.size = 0
self.root = {} self.root = {}
def __setitem__(self, key, value): def __setitem__(self, key, value):
@ -20,7 +21,8 @@ class TreeCache(object):
node = self.root node = self.root
for k in key[:-1]: for k in key[:-1]:
node = node.setdefault(k, {}) node = node.setdefault(k, {})
node[key[-1]] = value node[key[-1]] = _Entry(value)
self.size += 1
def get(self, key, default=None): def get(self, key, default=None):
node = self.root node = self.root
@ -28,9 +30,10 @@ class TreeCache(object):
node = node.get(k, None) node = node.get(k, None)
if node is None: if node is None:
return default return default
return node.get(key[-1], default) return node.get(key[-1], _Entry(default)).value
def clear(self): def clear(self):
self.size = 0
self.root = {} self.root = {}
def pop(self, key, default=None): def pop(self, key, default=None):
@ -57,4 +60,33 @@ class TreeCache(object):
break break
node_and_keys[i+1][0].pop(k) node_and_keys[i+1][0].pop(k)
popped, cnt = _strip_and_count_entires(popped)
self.size -= cnt
return popped return popped
def __len__(self):
return self.size
class _Entry(object):
__slots__ = ["value"]
def __init__(self, value):
self.value = value
def _strip_and_count_entires(d):
"""Takes an _Entry or dict with leaves of _Entry's, and either returns the
value or a dictionary with _Entry's replaced by their values.
Also returns the count of _Entry's
"""
if isinstance(d, dict):
cnt = 0
for key, value in d.items():
v, n = _strip_and_count_entires(value)
d[key] = v
cnt += n
return d, cnt
else:
return d.value, 1

View file

@ -382,19 +382,20 @@ class FilteringTestCase(unittest.TestCase):
"types": ["m.*"] "types": ["m.*"]
} }
} }
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart + "2",
user_filter=user_filter_json, user_filter=user_filter_json,
) )
event = MockEvent( event = MockEvent(
event_id="$asdasd:localhost",
sender="@foo:bar", sender="@foo:bar",
type="custom.avatar.3d.crazy", type="custom.avatar.3d.crazy",
) )
events = [event] events = [event]
user_filter = yield self.filtering.get_user_filter( user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart + "2",
filter_id=filter_id, filter_id=filter_id,
) )

View file

@ -1044,13 +1044,6 @@ class RoomMessageListTestCase(RestTestCase):
self.assertTrue("chunk" in response) self.assertTrue("chunk" in response)
self.assertTrue("end" in response) self.assertTrue("end" in response)
@defer.inlineCallbacks
def test_stream_token_is_rejected_for_back_pagination(self):
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=s0_0_0_0_0&dir=b" %
self.room_id)
self.assertEquals(400, code)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_stream_token_is_accepted_for_fwd_pagianation(self): def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0" token = "s0_0_0_0_0"

View file

@ -19,6 +19,7 @@ from .. import unittest
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache from synapse.util.caches.treecache import TreeCache
class LruCacheTestCase(unittest.TestCase): class LruCacheTestCase(unittest.TestCase):
def test_get_set(self): def test_get_set(self):
@ -72,3 +73,9 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get(("vehicles", "car")), "vroom") self.assertEquals(cache.get(("vehicles", "car")), "vroom")
self.assertEquals(cache.get(("vehicles", "train")), "chuff") self.assertEquals(cache.get(("vehicles", "train")), "chuff")
# Man from del_multi say "Yes". # Man from del_multi say "Yes".
def test_clear(self):
cache = LruCache(1)
cache["key"] = 1
cache.clear()
self.assertEquals(len(cache), 0)

View file

@ -25,6 +25,7 @@ class TreeCacheTestCase(unittest.TestCase):
cache[("b",)] = "B" cache[("b",)] = "B"
self.assertEquals(cache.get(("a",)), "A") self.assertEquals(cache.get(("a",)), "A")
self.assertEquals(cache.get(("b",)), "B") self.assertEquals(cache.get(("b",)), "B")
self.assertEquals(len(cache), 2)
def test_pop_onelevel(self): def test_pop_onelevel(self):
cache = TreeCache() cache = TreeCache()
@ -33,6 +34,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEquals(cache.pop(("a",)), "A") self.assertEquals(cache.pop(("a",)), "A")
self.assertEquals(cache.pop(("a",)), None) self.assertEquals(cache.pop(("a",)), None)
self.assertEquals(cache.get(("b",)), "B") self.assertEquals(cache.get(("b",)), "B")
self.assertEquals(len(cache), 1)
def test_get_set_twolevel(self): def test_get_set_twolevel(self):
cache = TreeCache() cache = TreeCache()
@ -42,6 +44,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get(("a", "a")), "AA") self.assertEquals(cache.get(("a", "a")), "AA")
self.assertEquals(cache.get(("a", "b")), "AB") self.assertEquals(cache.get(("a", "b")), "AB")
self.assertEquals(cache.get(("b", "a")), "BA") self.assertEquals(cache.get(("b", "a")), "BA")
self.assertEquals(len(cache), 3)
def test_pop_twolevel(self): def test_pop_twolevel(self):
cache = TreeCache() cache = TreeCache()
@ -53,6 +56,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get(("a", "b")), "AB") self.assertEquals(cache.get(("a", "b")), "AB")
self.assertEquals(cache.pop(("b", "a")), "BA") self.assertEquals(cache.pop(("b", "a")), "BA")
self.assertEquals(cache.pop(("b", "a")), None) self.assertEquals(cache.pop(("b", "a")), None)
self.assertEquals(len(cache), 1)
def test_pop_mixedlevel(self): def test_pop_mixedlevel(self):
cache = TreeCache() cache = TreeCache()
@ -64,3 +68,11 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get(("a", "a")), None) self.assertEquals(cache.get(("a", "a")), None)
self.assertEquals(cache.get(("a", "b")), None) self.assertEquals(cache.get(("a", "b")), None)
self.assertEquals(cache.get(("b", "a")), "BA") self.assertEquals(cache.get(("b", "a")), "BA")
self.assertEquals(len(cache), 1)
def test_clear(self):
cache = TreeCache()
cache[("a",)] = "A"
cache[("b",)] = "B"
cache.clear()
self.assertEquals(len(cache), 0)