Construct a source-specific 'SourcePaginationConfig' to pass into get_pagination_rows; meaning each source doesn't have to care about its own name any more

This commit is contained in:
Paul "LeoNerd" Evans 2014-10-29 15:57:23 +00:00
parent c5a25f610a
commit d6bcffa929
6 changed files with 43 additions and 40 deletions

View file

@ -115,8 +115,12 @@ class MessageHandler(BaseHandler):
user = self.hs.parse_userid(user_id)
events, next_token = yield data_source.get_pagination_rows(
user, pagin_config, room_id
events, next_key = yield data_source.get_pagination_rows(
user, pagin_config.get_source_config("room"), room_id
)
next_token = pagin_config.from_token.copy_and_replace(
"room_key", next_key
)
chunk = {
@ -271,7 +275,7 @@ class MessageHandler(BaseHandler):
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows(
user, pagination_config, None
user, pagination_config.get_source_config("presence"), None
)
public_rooms = yield self.store.get_rooms(is_public=True)

View file

@ -823,15 +823,12 @@ class PresenceEventSource(object):
def get_pagination_rows(self, user, pagination_config, key):
# TODO (erikj): Does this make sense? Ordering?
from_token = pagination_config.from_token
to_token = pagination_config.to_token
observer_user = user
from_key = int(from_token.presence_key)
from_key = int(pagination_config.from_key)
if to_token:
to_key = int(to_token.presence_key)
if pagination_config.to_key:
to_key = int(pagination_config.to_key)
else:
to_key = -1
@ -855,21 +852,9 @@ class PresenceEventSource(object):
earliest_serial = max([x[1].serial for x in updates])
data = [x[1].make_event(user=x[0], clock=clock) for x in updates]
if to_token:
next_token = to_token
else:
next_token = from_token
next_token = next_token.copy_and_replace(
"presence_key", earliest_serial
)
defer.returnValue((data, next_token))
defer.returnValue((data, earliest_serial))
else:
if not to_token:
to_token = from_token.copy_and_replace(
"presence_key", 0
)
defer.returnValue(([], to_token))
defer.returnValue(([], 0))
class UserPresenceCache(object):

View file

@ -612,23 +612,14 @@ class RoomEventSource(object):
return self.store.get_room_events_max_id()
@defer.inlineCallbacks
def get_pagination_rows(self, user, pagination_config, key):
from_token = pagination_config.from_token
to_token = pagination_config.to_token
limit = pagination_config.limit
direction = pagination_config.direction
to_key = to_token.room_key if to_token else None
def get_pagination_rows(self, user, config, key):
events, next_key = yield self.store.paginate_room_events(
room_id=key,
from_key=from_token.room_key,
to_key=to_key,
direction=direction,
limit=limit,
from_key=config.from_key,
to_key=config.to_key,
direction=config.direction,
limit=config.limit,
with_feedback=True
)
next_token = from_token.copy_and_replace("room_key", next_key)
defer.returnValue((events, next_token))
defer.returnValue((events, next_key))

View file

@ -158,4 +158,4 @@ class TypingNotificationEventSource(object):
return 0
def get_pagination_rows(self, user, pagination_config, key):
return ([], pagination_config.from_token)
return ([], pagination_config.from_key)

View file

@ -22,6 +22,19 @@ import logging
logger = logging.getLogger(__name__)
class SourcePaginationConfig(object):
"""A configuration object which stores pagination parameters for a
specific event source."""
def __init__(self, from_key=None, to_key=None, direction='f',
limit=0):
self.from_key = from_key
self.to_key = to_key
self.direction = 'f' if direction == 'f' else 'b'
self.limit = int(limit)
class PaginationConfig(object):
"""A configuration object which stores pagination parameters."""
@ -82,3 +95,13 @@ class PaginationConfig(object):
"<PaginationConfig from_tok=%s, to_tok=%s, "
"direction=%s, limit=%s>"
) % (self.from_token, self.to_token, self.direction, self.limit)
def get_source_config(self, source_name):
keyname = "%s_key" % source_name
return SourcePaginationConfig(
from_key=getattr(self.from_token, keyname),
to_key=getattr(self.to_token, keyname) if self.to_token else None,
direction=self.direction,
limit=self.limit,
)

View file

@ -35,7 +35,7 @@ class NullSource(object):
return defer.succeed(0)
def get_pagination_rows(self, user, pagination_config, key):
return defer.succeed(([], pagination_config.from_token))
return defer.succeed(([], pagination_config.from_key))
class EventSources(object):