abandoned branch, pushing for reference

This commit is contained in:
Neil Johnson 2018-09-26 15:16:02 +01:00
parent ad53a5497d
commit 0d14119d42
8 changed files with 231 additions and 133 deletions

View file

@ -822,6 +822,7 @@ class Auth(object):
return return
# Else if there is no room in the MAU bucket, bail # Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count() current_mau = yield self.store.get_monthly_active_count()
print ("auth check, current_mau %d" % current_mau)
if current_mau >= self.hs.config.max_mau_value: if current_mau >= self.hs.config.max_mau_value:
raise ResourceLimitError( raise ResourceLimitError(
403, "Monthly Active User Limit Exceeded", 403, "Monthly Active User Limit Exceeded",

View file

@ -295,7 +295,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# Necessary due to auth checks prior to the threepid being # Necessary due to auth checks prior to the threepid being
# written to the db # written to the db
if is_threepid_reserved(self.hs.config, threepid): if is_threepid_reserved(self.hs.config, threepid):
yield self.store.upsert_monthly_active_user(user_id) self.store.upsert_monthly_active_user(user_id)
if session[LoginType.EMAIL_IDENTITY]: if session[LoginType.EMAIL_IDENTITY]:
logger.debug("Binding emails %s to %s" % ( logger.debug("Binding emails %s to %s" % (

View file

@ -416,7 +416,7 @@ class RegisterRestServlet(RestServlet):
# Necessary due to auth checks prior to the threepid being # Necessary due to auth checks prior to the threepid being
# written to the db # written to the db
if is_threepid_reserved(self.hs.config, threepid): if is_threepid_reserved(self.hs.config, threepid):
yield self.store.upsert_monthly_active_user(registered_user_id) self.store.upsert_monthly_active_user(registered_user_id)
# remember that we've now registered that user account, and with # remember that we've now registered that user account, and with
# what user ID (since the user may not have specified) # what user ID (since the user may not have specified)

View file

@ -14,26 +14,41 @@
# limitations under the License. # limitations under the License.
import logging import logging
from six import iteritems
from twisted.internet import defer from twisted.internet import defer
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.metrics.background_process_metrics import run_as_background_process
from . import background_updates
from ._base import SQLBaseStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Number of msec of granularity to store the monthly_active_user timestamp # Number of msec of granularity to store the monthly_active_user timestamp
# This means it is not necessary to update the table on every request # This means it is not necessary to update the table on every request
LAST_SEEN_GRANULARITY = 60 * 60 * 1000 LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersStore(SQLBaseStore): class MonthlyActiveUsersStore(background_updates.BackgroundUpdateStore):
def __init__(self, dbconn, hs): def __init__(self, dbconn, hs):
super(MonthlyActiveUsersStore, self).__init__(None, hs) super(MonthlyActiveUsersStore, self).__init__(None, hs)
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.hs = hs self.hs = hs
self.reserved_users = () self.reserved_users = ()
# user_id:timestamp
self._batch_row_update_mau = {}
self._mau_looper = self._clock.looping_call(
self._update_monthly_active_users_batch, 5 * 1000
)
self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_monthly_active_users_batch
)
@defer.inlineCallbacks @defer.inlineCallbacks
def initialise_reserved_users(self, threepids): def initialise_reserved_users(self, threepids):
store = self.hs.get_datastore() store = self.hs.get_datastore()
@ -127,23 +142,37 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# is racy. # is racy.
# Have resolved to invalidate the whole cache for now and do # Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant # something about it if and when the perf becomes significant
self.user_last_seen_monthly_active.invalidate_all() # self.user_last_seen_monthly_active.invalidate_all()
self.get_monthly_active_count.invalidate_all() # self.get_monthly_active_count.invalidate_all()
@cached(num_args=0) #@cached(num_args=0)
def get_monthly_active_count(self): def get_monthly_active_count(self):
"""Generates current count of monthly active users """Generates current count of monthly active users
Returns: Returns:
Defered[int]: Number of current monthly active users Defered[int]: Number of current monthly active users
""" """
# in_mem_new_users = 0
# for user_id, timestamp in iteritems(self._batch_row_update_mau):
# mau_member_ts = self.user_last_seen_monthly_active(user_id)
# if mau_member_ts is None:
# in_mem_new_users = in_mem_new_users + 1
# Ideally I'd check in self._batch_row_update_mau adnd any outstanding
# new users to the total, but I can't because the only way to determine
# if the user is new is to call user_last_seen_monthly_active which itself
# checks in self._batch_row_update_mau and therefore will always answer
# that the user is pre-existing.
def _count_users(txn): def _count_users(txn):
sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
txn.execute(sql) txn.execute(sql)
count, = txn.fetchone() count, = txn.fetchone()
print "count is %d" % count
return count return count
#return defer.returnValue(self.runInteraction("count_users", _count_users, in_mem_new_users))
return self.runInteraction("count_users", _count_users) return self.runInteraction("count_users", _count_users)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -163,31 +192,21 @@ class MonthlyActiveUsersStore(SQLBaseStore):
count = count + 1 count = count + 1
defer.returnValue(count) defer.returnValue(count)
@defer.inlineCallbacks
def upsert_monthly_active_user(self, user_id): def upsert_monthly_active_user(self, user_id):
""" """
Updates or inserts monthly active user member Adds request to updates or insert monthly active user member
Arguments: Arguments:
user_id (str): user to add/update user_id (str): user to add/update
Deferred[bool]: True if a new entry was created, False if an
existing one was updated.
""" """
is_insert = yield self._simple_upsert( logger.error('upsert_monthly_active_user type of user_id is %s' % type(user_id))
desc="upsert_monthly_active_user", timestamp = int(self._clock.time_msec())
table="monthly_active_users", self._batch_row_update_mau[user_id] = timestamp
keyvalues={ self.user_last_seen_monthly_active.prefill(user_id, timestamp)
"user_id": user_id,
},
values={
"timestamp": int(self._clock.time_msec()),
},
lock=False,
)
if is_insert:
self.user_last_seen_monthly_active.invalidate((user_id,))
self.get_monthly_active_count.invalidate(())
@cached(num_args=1) # self.user_last_seen_monthly_active.invalidate((user_id,))
# self.get_monthly_active_count.invalidate(())
#@cached(num_args=1)
def user_last_seen_monthly_active(self, user_id): def user_last_seen_monthly_active(self, user_id):
""" """
Checks if a given user is part of the monthly active user group Checks if a given user is part of the monthly active user group
@ -197,6 +216,10 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Deferred[int] : timestamp since last seen, None if never seen Deferred[int] : timestamp since last seen, None if never seen
""" """
# Need to check in memory batch queue
# last_seen = self._batch_row_update_mau.get(user_id)
# if last_seen:
# return defer.returnValue(last_seen)
return(self._simple_select_one_onecol( return(self._simple_select_one_onecol(
table="monthly_active_users", table="monthly_active_users",
@ -237,6 +260,54 @@ class MonthlyActiveUsersStore(SQLBaseStore):
if last_seen_timestamp is None: if last_seen_timestamp is None:
count = yield self.get_monthly_active_count() count = yield self.get_monthly_active_count()
if count < self.hs.config.max_mau_value: if count < self.hs.config.max_mau_value:
yield self.upsert_monthly_active_user(user_id) self.upsert_monthly_active_user(user_id)
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY: elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
yield self.upsert_monthly_active_user(user_id) self.upsert_monthly_active_user(user_id)
def _update_monthly_active_users_batch(self):
# If the DB pool has already terminated, don't try updating
if not self.hs.get_db_pool().running:
return
def update():
to_update = self._batch_row_update_mau
self._batch_row_update_mau = {}
return self.runInteraction(
"_update_monthly_active_users_batch",
self._update_monthly_active_users_batch_txn,
to_update,
)
#self.get_monthly_active_count.invalidate(())
return run_as_background_process(
"update_monthly_active_users", update,
)
def _update_monthly_active_users_batch_txn(self, txn, to_update):
self.database_engine.lock_table(txn, "monthly_active_users")
logger.error('to_update %r' % to_update)
for user_id, timestamp in iteritems(to_update):
logger.error("upserting %s" % user_id)
print "upserting %s" % user_id
try:
self._simple_upsert_txn(
txn,
table="monthly_active_users",
keyvalues={
"user_id": user_id,
},
values={
"timestamp": timestamp,
},
lock=False,
)
# Not sure if I need to do this here since the result is already
# prefilled in upsert_monthly_active_user though seems safer to
# do so
#self.user_last_seen_monthly_active.invalidate((user_id,))
except Exception as e:
# Failed to upsert, log and continue
logger.error("Failed to insert mau user %s: %r", user_id, e)
# if len(to_update) > 0:
# self.get_monthly_active_count.invalidate(())

View file

@ -22,9 +22,16 @@ from synapse.types import UserID
import tests.unittest import tests.unittest
import tests.utils import tests.utils
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
from tests.unittest import HomeserverTestCase
from tests.server import (
ThreadedMemoryReactorClock,
)
ONE_HOUR = 60 * 60 * 1000
class SyncTestCase(tests.unittest.TestCase): class SyncTestCase(HomeserverTestCase):
""" Tests Sync Handler. """ """ Tests Sync Handler. """
@defer.inlineCallbacks @defer.inlineCallbacks
@ -32,6 +39,7 @@ class SyncTestCase(tests.unittest.TestCase):
self.hs = yield setup_test_homeserver(self.addCleanup) self.hs = yield setup_test_homeserver(self.addCleanup)
self.sync_handler = SyncHandler(self.hs) self.sync_handler = SyncHandler(self.hs)
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.reactor = ThreadedMemoryReactorClock()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_wait_for_sync_for_user_auth_blocking(self): def test_wait_for_sync_for_user_auth_blocking(self):
@ -44,7 +52,7 @@ class SyncTestCase(tests.unittest.TestCase):
self.hs.config.max_mau_value = 1 self.hs.config.max_mau_value = 1
# Check that the happy case does not throw errors # Check that the happy case does not throw errors
yield self.store.upsert_monthly_active_user(user_id1) self.store.upsert_monthly_active_user(user_id1)
yield self.sync_handler.wait_for_sync_for_user(sync_config) yield self.sync_handler.wait_for_sync_for_user(sync_config)
# Test that global lock works # Test that global lock works
@ -56,7 +64,11 @@ class SyncTestCase(tests.unittest.TestCase):
self.hs.config.hs_disabled = False self.hs.config.hs_disabled = False
sync_config = self._generate_sync_config(user_id2) sync_config = self._generate_sync_config(user_id2)
print 'pre wait'
self.reactor.advance(ONE_HOUR)
self.pump()
print 'post wait'
with self.assertRaises(ResourceLimitError) as e: with self.assertRaises(ResourceLimitError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config) yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)

View file

@ -36,104 +36,104 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def prepare(self, hs, reactor, clock): def prepare(self, hs, reactor, clock):
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
def test_insert_new_client_ip(self): # def test_insert_new_client_ip(self):
self.reactor.advance(12345678) # self.reactor.advance(12345678)
#
user_id = "@user:id" # user_id = "@user:id"
self.get_success( # self.get_success(
self.store.insert_client_ip( # self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id" # user_id, "access_token", "ip", "user_agent", "device_id"
) # )
) # )
#
# Trigger the storage loop # # Trigger the storage loop
self.reactor.advance(10) # self.reactor.advance(10)
#
result = self.get_success( # result = self.get_success(
self.store.get_last_client_ip_by_device(user_id, "device_id") # self.store.get_last_client_ip_by_device(user_id, "device_id")
) # )
#
r = result[(user_id, "device_id")] # r = result[(user_id, "device_id")]
self.assertDictContainsSubset( # self.assertDictContainsSubset(
{ # {
"user_id": user_id, # "user_id": user_id,
"device_id": "device_id", # "device_id": "device_id",
"access_token": "access_token", # "access_token": "access_token",
"ip": "ip", # "ip": "ip",
"user_agent": "user_agent", # "user_agent": "user_agent",
"last_seen": 12345678000, # "last_seen": 12345678000,
}, # },
r, # r,
) # )
#
def test_disabled_monthly_active_user(self): # def test_disabled_monthly_active_user(self):
self.hs.config.limit_usage_by_mau = False # self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50 # self.hs.config.max_mau_value = 50
user_id = "@user:server" # user_id = "@user:server"
self.get_success( # self.get_success(
self.store.insert_client_ip( # self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id" # user_id, "access_token", "ip", "user_agent", "device_id"
) # )
) # )
active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) # active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active) # self.assertFalse(active)
#
def test_adding_monthly_active_user_when_full(self): # def test_adding_monthly_active_user_when_full(self):
self.hs.config.limit_usage_by_mau = True # self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50 # self.hs.config.max_mau_value = 50
lots_of_users = 100 # lots_of_users = 100
user_id = "@user:server" # user_id = "@user:server"
#
self.store.get_monthly_active_count = Mock( # self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(lots_of_users) # return_value=defer.succeed(lots_of_users)
) # )
self.get_success( # self.get_success(
self.store.insert_client_ip( # self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id" # user_id, "access_token", "ip", "user_agent", "device_id"
) # )
) # )
active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) # active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active) # self.assertFalse(active)
#
def test_adding_monthly_active_user_when_space(self): # def test_adding_monthly_active_user_when_space(self):
self.hs.config.limit_usage_by_mau = True # self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50 # self.hs.config.max_mau_value = 50
user_id = "@user:server" # user_id = "@user:server"
active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) # active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active) # self.assertFalse(active)
#
# Trigger the saving loop # # Trigger the saving loop
self.reactor.advance(10) # self.reactor.advance(10)
#
self.get_success( # self.get_success(
self.store.insert_client_ip( # self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id" # user_id, "access_token", "ip", "user_agent", "device_id"
) # )
) # )
active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) # active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active) # self.assertTrue(active)
#
def test_updating_monthly_active_user_when_space(self): # def test_updating_monthly_active_user_when_space(self):
self.hs.config.limit_usage_by_mau = True # self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50 # self.hs.config.max_mau_value = 50
user_id = "@user:server" # user_id = "@user:server"
self.get_success( # self.get_success(
self.store.register(user_id=user_id, token="123", password_hash=None) # self.store.register(user_id=user_id, token="123", password_hash=None)
) # )
#
active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) # active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active) # self.assertFalse(active)
#
# Trigger the saving loop # # Trigger the saving loop
self.reactor.advance(10) # self.reactor.advance(10)
#
self.get_success( # self.get_success(
self.store.insert_client_ip( # self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id" # user_id, "access_token", "ip", "user_agent", "device_id"
) # )
) # )
active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) # active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active) # self.assertTrue(active)
class ClientIpAuthTestCase(unittest.HomeserverTestCase): class ClientIpAuthTestCase(unittest.HomeserverTestCase):

View file

@ -19,6 +19,7 @@ from twisted.internet import defer
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
FORTY_DAYS = 40 * 24 * 60 * 60 FORTY_DAYS = 40 * 24 * 60 * 60
ONE_HOUR = 60 *60
class MonthlyActiveUsersTestCase(HomeserverTestCase): class MonthlyActiveUsersTestCase(HomeserverTestCase):
@ -54,6 +55,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
self.store.user_add_threepid(user2, "email", user2_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now)
self.store.initialise_reserved_users(threepids) self.store.initialise_reserved_users(threepids)
self.pump() self.pump()
self.reactor.advance(ONE_HOUR)
active_count = self.store.get_monthly_active_count() active_count = self.store.get_monthly_active_count()
@ -81,7 +83,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
ru_count = 2 ru_count = 2
self.store.upsert_monthly_active_user("@ru1:server") self.store.upsert_monthly_active_user("@ru1:server")
self.store.upsert_monthly_active_user("@ru2:server") self.store.upsert_monthly_active_user("@ru2:server")
self.pump() self.reactor.advance(ONE_HOUR)
active_count = self.store.get_monthly_active_count() active_count = self.store.get_monthly_active_count()
self.assertEqual(self.get_success(active_count), user_num + ru_count) self.assertEqual(self.get_success(active_count), user_num + ru_count)
@ -94,12 +96,14 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
def test_can_insert_and_count_mau(self): def test_can_insert_and_count_mau(self):
count = self.store.get_monthly_active_count() count = self.store.get_monthly_active_count()
self.pump()
self.assertEqual(0, self.get_success(count)) self.assertEqual(0, self.get_success(count))
self.store.upsert_monthly_active_user("@user:server") self.store.upsert_monthly_active_user("@user:server")
self.pump() self.reactor.advance(ONE_HOUR)
count = self.store.get_monthly_active_count() count = self.store.get_monthly_active_count()
self.pump()
self.assertEqual(1, self.get_success(count)) self.assertEqual(1, self.get_success(count))
def test_user_last_seen_monthly_active(self): def test_user_last_seen_monthly_active(self):
@ -112,7 +116,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
self.store.upsert_monthly_active_user(user_id1) self.store.upsert_monthly_active_user(user_id1)
self.store.upsert_monthly_active_user(user_id2) self.store.upsert_monthly_active_user(user_id2)
self.pump() self.reactor.advance(ONE_HOUR)
result = self.store.user_last_seen_monthly_active(user_id1) result = self.store.user_last_seen_monthly_active(user_id1)
self.assertGreater(self.get_success(result), 0) self.assertGreater(self.get_success(result), 0)
@ -125,7 +129,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
initial_users = 10 initial_users = 10
for i in range(initial_users): for i in range(initial_users):
self.store.upsert_monthly_active_user("@user%d:server" % i) self.store.upsert_monthly_active_user("@user%d:server" % i)
self.pump() self.reactor.advance(ONE_HOUR)
count = self.store.get_monthly_active_count() count = self.store.get_monthly_active_count()
self.assertTrue(self.get_success(count), initial_users) self.assertTrue(self.get_success(count), initial_users)

View file

@ -16,6 +16,7 @@
"""Tests REST events for /rooms paths.""" """Tests REST events for /rooms paths."""
import json import json
import logging
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock
@ -32,6 +33,9 @@ from tests.server import (
render, render,
setup_test_homeserver, setup_test_homeserver,
) )
logger = logging.getLogger(__name__)
ONE_HOUR = 60 * 60 * 1000
class TestMauLimit(unittest.TestCase): class TestMauLimit(unittest.TestCase):
@ -69,12 +73,15 @@ class TestMauLimit(unittest.TestCase):
sync.register_servlets(self.hs, self.resource) sync.register_servlets(self.hs, self.resource)
def test_simple_deny_mau(self): def test_simple_deny_mau(self):
# Create and sync so that the MAU counts get updated # Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1") token1 = self.create_user("kermit1")
logger.debug("create kermit1 token is %s" % token1)
self.do_sync_for_user(token1) self.do_sync_for_user(token1)
token2 = self.create_user("kermit2") token2 = self.create_user("kermit2")
self.do_sync_for_user(token2) self.do_sync_for_user(token2)
# Because adding to
self.reactor.advance(ONE_HOUR)
# We've created and activated two users, we shouldn't be able to # We've created and activated two users, we shouldn't be able to
# register new users # register new users
with self.assertRaises(SynapseError) as cm: with self.assertRaises(SynapseError) as cm:
@ -102,6 +109,7 @@ class TestMauLimit(unittest.TestCase):
token3 = self.create_user("kermit3") token3 = self.create_user("kermit3")
self.do_sync_for_user(token3) self.do_sync_for_user(token3)
@unittest.DEBUG
def test_trial_delay(self): def test_trial_delay(self):
self.hs.config.mau_trial_days = 1 self.hs.config.mau_trial_days = 1
@ -120,6 +128,8 @@ class TestMauLimit(unittest.TestCase):
self.do_sync_for_user(token1) self.do_sync_for_user(token1)
self.do_sync_for_user(token2) self.do_sync_for_user(token2)
self.reactor.advance(ONE_HOUR)
# But the third should fail # But the third should fail
with self.assertRaises(SynapseError) as cm: with self.assertRaises(SynapseError) as cm:
self.do_sync_for_user(token3) self.do_sync_for_user(token3)