Add database config class (#6513)

This encapsulates config for a given database and is the way to get new
connections.
Erik Johnston 2019-12-18 10:45:12 +00:00 committed by GitHub
parent 91ccfe9f37
commit 2284eb3a53
19 changed files with 286 additions and 208 deletions
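At a glance, the new pattern this commit introduces (an illustrative sketch, not part of the commit; the in-memory sqlite config is a made-up example, and it assumes this revision of Synapse is on the import path):

    from synapse.config.database import DatabaseConnectionConfig
    from synapse.storage.database import make_conn
    from synapse.storage.engines import create_engine

    # A made-up connection config; "master" is only a label used in log lines.
    db_config = DatabaseConnectionConfig(
        "master",
        {"name": "sqlite3", "args": {"database": ":memory:"}},
        data_stores=["main"],
    )

    engine = create_engine(db_config.config)  # picks Sqlite3Engine for this config
    conn = make_conn(db_config, engine)       # strips cp_* pool args, runs on_new_connection
    conn.cursor().execute("SELECT 1")
    conn.close()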

changelog.d/6513.misc (new file)

@@ -0,0 +1 @@
+Remove all assumptions of there being a single physical DB apart from the `synapse.config`.


@@ -26,7 +26,6 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.server import HomeServer
 from synapse.storage import DataStore
-from synapse.storage.prepare_database import prepare_database

 logger = logging.getLogger("update_database")

@@ -77,12 +76,8 @@ if __name__ == "__main__":
     # Instantiate and initialise the homeserver object.
     hs = MockHomeserver(config)

-    db_conn = hs.get_db_conn()
-
-    # Update the database to the latest schema.
-    prepare_database(db_conn, hs.database_engine, config=config)
-    db_conn.commit()
-
-    # setup instantiates the store within the homeserver object.
+    # Setup instantiates the store within the homeserver object and updates the
+    # DB.
     hs.setup()
     store = hs.get_datastore()


@@ -30,6 +30,7 @@ import yaml
 from twisted.enterprise import adbapi
 from twisted.internet import defer, reactor

+from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
 from synapse.logging.context import PreserveLoggingContext
 from synapse.storage._base import LoggingTransaction

@@ -55,7 +56,7 @@ from synapse.storage.data_stores.main.stats import StatsStore
 from synapse.storage.data_stores.main.user_directory import (
     UserDirectoryBackgroundUpdateStore,
 )
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_conn
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 from synapse.util import Clock

@@ -165,23 +166,17 @@ class Store(
 class MockHomeserver:
-    def __init__(self, config, database_engine, db_conn, db_pool):
-        self.database_engine = database_engine
-        self.db_conn = db_conn
-        self.db_pool = db_pool
+    def __init__(self, config):
         self.clock = Clock(reactor)
         self.config = config
         self.hostname = config.server_name

-    def get_db_conn(self):
-        return self.db_conn
-
-    def get_db_pool(self):
-        return self.db_pool
-
     def get_clock(self):
         return self.clock

+    def get_reactor(self):
+        return reactor
+

 class Porter(object):
     def __init__(self, **kwargs):

@@ -445,45 +440,36 @@ class Porter(object):
         else:
             return

-    def setup_db(self, db_config, database_engine):
-        db_conn = database_engine.module.connect(
-            **{
-                k: v
-                for k, v in db_config.get("args", {}).items()
-                if not k.startswith("cp_")
-            }
-        )
-        prepare_database(db_conn, database_engine, config=None)
+    def setup_db(self, db_config: DatabaseConnectionConfig, engine):
+        db_conn = make_conn(db_config, engine)
+        prepare_database(db_conn, engine, config=None)

         db_conn.commit()

         return db_conn

     @defer.inlineCallbacks
-    def build_db_store(self, config):
+    def build_db_store(self, db_config: DatabaseConnectionConfig):
         """Builds and returns a database store using the provided configuration.

         Args:
-            config: The database configuration, i.e. a dict following the structure of
-                the "database" section of Synapse's configuration file.
+            config: The database configuration

         Returns:
             The built Store object.
         """
-        engine = create_engine(config)
-        self.progress.set_state("Preparing %s" % config["name"])
-        conn = self.setup_db(config, engine)
-        db_pool = adbapi.ConnectionPool(config["name"], **config["args"])
-        hs = MockHomeserver(self.hs_config, engine, conn, db_pool)
-        store = Store(Database(hs), conn, hs)
+        self.progress.set_state("Preparing %s" % db_config.config["name"])
+        engine = create_engine(db_config.config)
+        conn = self.setup_db(db_config, engine)
+        hs = MockHomeserver(self.hs_config)
+        store = Store(Database(hs, db_config, engine), conn, hs)

         yield store.db.runInteraction(
-            "%s_engine.check_database" % config["name"], engine.check_database,
+            "%s_engine.check_database" % db_config.config["name"],
+            engine.check_database,
         )

         return store

@@ -509,7 +495,11 @@ class Porter(object):
     @defer.inlineCallbacks
     def run(self):
         try:
-            self.sqlite_store = yield self.build_db_store(self.sqlite_config)
+            self.sqlite_store = yield self.build_db_store(
+                DatabaseConnectionConfig(
+                    "master", self.sqlite_config, data_stores=["main"]
+                )
+            )

             # Check if all background updates are done, abort if not.
             updates_complete = (

@@ -524,7 +514,7 @@ class Porter(object):
             defer.returnValue(None)

         self.postgres_store = yield self.build_db_store(
-            self.hs_config.database_config
+            self.hs_config.get_single_database()
         )

         yield self.run_background_updates_on_postgres()


@@ -12,12 +12,43 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 from textwrap import indent
+from typing import List

 import yaml

-from ._base import Config
+from synapse.config._base import Config, ConfigError

+logger = logging.getLogger(__name__)
+
+
+class DatabaseConnectionConfig:
+    """Contains the connection config for a particular database.
+
+    Args:
+        name: A label for the database, used for logging.
+        db_config: The config for a particular database, as per `database`
+            section of main config. Has two fields: `name` for database
+            module name, and `args` for the args to give to the database
+            connector.
+        data_stores: The list of data stores that should be provisioned on the
+            database.
+    """
+
+    def __init__(self, name: str, db_config: dict, data_stores: List[str]):
+        if db_config["name"] not in ("sqlite3", "psycopg2"):
+            raise ConfigError("Unsupported database type %r" % (db_config["name"],))
+
+        if db_config["name"] == "sqlite3":
+            db_config.setdefault("args", {}).update(
+                {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
+            )
+
+        self.name = name
+        self.config = db_config
+        self.data_stores = data_stores


 class DatabaseConfig(Config):

@@ -26,20 +57,14 @@ class DatabaseConfig(Config):
     def read_config(self, config, **kwargs):
         self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))

-        self.database_config = config.get("database")
+        database_config = config.get("database")

-        if self.database_config is None:
-            self.database_config = {"name": "sqlite3", "args": {}}
+        if database_config is None:
+            database_config = {"name": "sqlite3", "args": {}}

-        name = self.database_config.get("name", None)
-        if name == "psycopg2":
-            pass
-        elif name == "sqlite3":
-            self.database_config.setdefault("args", {}).update(
-                {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
-            )
-        else:
-            raise RuntimeError("Unsupported database type '%s'" % (name,))
+        self.databases = [
+            DatabaseConnectionConfig("master", database_config, data_stores=["main"])
+        ]

         self.set_databasepath(config.get("database_path"))

@@ -76,11 +101,24 @@ class DatabaseConfig(Config):
             self.set_databasepath(args.database_path)

     def set_databasepath(self, database_path):
+        if database_path is None:
+            return
+
         if database_path != ":memory:":
             database_path = self.abspath(database_path)
-        if self.database_config.get("name", None) == "sqlite3":
-            if database_path is not None:
-                self.database_config["args"]["database"] = database_path
+
+        # We only support setting a database path if we have a single sqlite3
+        # database.
+        if len(self.databases) != 1:
+            raise ConfigError("Cannot specify 'database_path' with multiple databases")
+
+        database = self.get_single_database()
+        if database.config["name"] != "sqlite3":
+            # We don't raise here as we haven't done so before for this case.
+            logger.warn("Ignoring 'database_path' for non-sqlite3 database")
+            return
+
+        database.config["args"]["database"] = database_path

     @staticmethod
     def add_arguments(parser):

@@ -91,3 +129,11 @@ class DatabaseConfig(Config):
             metavar="SQLITE_DATABASE_PATH",
             help="The path to a sqlite database to use.",
         )
+
+    def get_single_database(self) -> DatabaseConnectionConfig:
+        """Returns the database if there is only one, useful for e.g. tests
+        """
+        if len(self.databases) != 1:
+            raise Exception("More than one database exists")
+
+        return self.databases[0]
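As an aside (editor's sketch, not part of the diff): the type-checking that used to live in read_config now happens in DatabaseConnectionConfig itself, so it can be exercised directly; the "mysql" value below is a deliberately unsupported example:

    from synapse.config._base import ConfigError
    from synapse.config.database import DatabaseConnectionConfig

    # sqlite3 configs are forced to a single-connection pool:
    db = DatabaseConnectionConfig("master", {"name": "sqlite3"}, data_stores=["main"])
    assert db.config["args"] == {"cp_min": 1, "cp_max": 1, "check_same_thread": False}

    # Anything other than sqlite3/psycopg2 is rejected with a ConfigError
    # (read_config previously raised a RuntimeError for this):
    try:
        DatabaseConnectionConfig("master", {"name": "mysql"}, data_stores=["main"])
    except ConfigError as e:
        print(e)  # Unsupported database type 'mysql'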


@@ -230,7 +230,7 @@ class PresenceHandler(object):
         is some spurious presence changes that will self-correct.
         """
         # If the DB pool has already terminated, don't try updating
-        if not self.hs.get_db_pool().running:
+        if not self.store.db.is_running():
             return

         logger.info(


@@ -25,7 +25,6 @@ import abc
 import logging
 import os

-from twisted.enterprise import adbapi
 from twisted.mail.smtp import sendmail
 from twisted.web.client import BrowserLikePolicyForHTTPS

@@ -98,7 +97,6 @@ from synapse.server_notices.worker_server_notices_sender import (
 )
 from synapse.state import StateHandler, StateResolutionHandler
 from synapse.storage import DataStores, Storage
-from synapse.storage.engines import create_engine
 from synapse.streams.events import EventSources
 from synapse.util import Clock
 from synapse.util.distributor import Distributor

@@ -134,7 +132,6 @@ class HomeServer(object):
     DEPENDENCIES = [
         "http_client",
-        "db_pool",
         "federation_client",
         "federation_server",
         "handlers",

@@ -233,12 +230,6 @@ class HomeServer(object):
         self.admin_redaction_ratelimiter = Ratelimiter()
         self.registration_ratelimiter = Ratelimiter()

-        self.database_engine = create_engine(config.database_config)
-        config.database_config.setdefault("args", {})[
-            "cp_openfun"
-        ] = self.database_engine.on_new_connection
-        self.db_config = config.database_config
         self.datastores = None

         # Other kwargs are explicit dependencies

@@ -247,10 +238,8 @@ class HomeServer(object):
     def setup(self):
         logger.info("Setting up.")

-        with self.get_db_conn() as conn:
-            self.datastores = DataStores(self.DATASTORE_CLASS, conn, self)
-            conn.commit()
-
         self.start_time = int(self.get_clock().time())
+        self.datastores = DataStores(self.DATASTORE_CLASS, self)
+
         logger.info("Finished setting up.")

     def setup_master(self):

@@ -284,6 +273,9 @@ class HomeServer(object):
     def get_datastore(self):
         return self.datastores.main

+    def get_datastores(self):
+        return self.datastores
+
     def get_config(self):
         return self.config

@@ -433,31 +425,6 @@ class HomeServer(object):
         )
         return MatrixFederationHttpClient(self, tls_client_options_factory)

-    def build_db_pool(self):
-        name = self.db_config["name"]
-
-        return adbapi.ConnectionPool(
-            name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {})
-        )
-
-    def get_db_conn(self, run_new_connection=True):
-        """Makes a new connection to the database, skipping the db pool
-
-        Returns:
-            Connection: a connection object implementing the PEP-249 spec
-        """
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v
-            for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def build_media_repository_resource(self):
         # build the media repo resource. This indirects through the HomeServer
         # to ensure that we only have a single instance of


@@ -40,7 +40,7 @@ class SQLBaseStore(object):
     def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
-        self.database_engine = hs.database_engine
+        self.database_engine = database.engine
         self.db = database
         self.rand = random.SystemRandom()


@@ -13,24 +13,50 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from synapse.storage.database import Database
+import logging
+
+from synapse.storage.database import Database, make_conn
+from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database

+logger = logging.getLogger(__name__)
+

 class DataStores(object):
     """The various data stores.

     These are low level interfaces to physical databases.
+
+    Attributes:
+        main (DataStore)
     """

-    def __init__(self, main_store_class, db_conn, hs):
+    def __init__(self, main_store_class, hs):
         # Note we pass in the main store class here as workers use a different main
         # store.
-        database = Database(hs)

-        # Check that db is correctly configured.
-        database.engine.check_database(db_conn.cursor())
+        self.databases = []

-        prepare_database(db_conn, database.engine, config=hs.config)
+        for database_config in hs.config.database.databases:
+            db_name = database_config.name
+            engine = create_engine(database_config.config)

-        self.main = main_store_class(database, db_conn, hs)
+            with make_conn(database_config, engine) as db_conn:
+                logger.info("Preparing database %r...", db_name)
+
+                engine.check_database(db_conn.cursor())
+                prepare_database(
+                    db_conn, engine, hs.config, data_stores=database_config.data_stores,
+                )
+
+                database = Database(hs, database_config, engine)
+
+                if "main" in database_config.data_stores:
+                    logger.info("Starting 'main' data store")
+                    self.main = main_store_class(database, db_conn, hs)
+
+                db_conn.commit()
+
+            self.databases.append(database)
+
+            logger.info("Database %r prepared", db_name)


@@ -412,7 +412,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
     def _update_client_ips_batch(self):

         # If the DB pool has already terminated, don't try updating
-        if not self.hs.get_db_pool().running:
+        if not self.db.is_running():
             return

         to_update = self._batch_row_update


@@ -24,9 +24,11 @@ from six.moves import intern, range
 from prometheus_client import Histogram

+from twisted.enterprise import adbapi
 from twisted.internet import defer

 from synapse.api.errors import StoreError
+from synapse.config.database import DatabaseConnectionConfig
 from synapse.logging.context import LoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater

@@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
 }

+
+def make_pool(
+    reactor, db_config: DatabaseConnectionConfig, engine
+) -> adbapi.ConnectionPool:
+    """Get the connection pool for the database.
+    """
+
+    return adbapi.ConnectionPool(
+        db_config.config["name"],
+        cp_reactor=reactor,
+        cp_openfun=engine.on_new_connection,
+        **db_config.config.get("args", {})
+    )
+
+
+def make_conn(db_config: DatabaseConnectionConfig, engine):
+    """Make a new connection to the database and return it.
+
+    Returns:
+        Connection
+    """
+
+    db_params = {
+        k: v
+        for k, v in db_config.config.get("args", {}).items()
+        if not k.startswith("cp_")
+    }
+    db_conn = engine.module.connect(**db_params)
+    engine.on_new_connection(db_conn)
+    return db_conn
+
+
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()

@@ -218,10 +251,11 @@ class Database(object):
     _TXN_ID = 0

-    def __init__(self, hs):
+    def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
         self.hs = hs
         self._clock = hs.get_clock()
-        self._db_pool = hs.get_db_pool()
+        self._database_config = database_config
+        self._db_pool = make_pool(hs.get_reactor(), database_config, engine)

         self.updates = BackgroundUpdater(hs, self)

@@ -234,7 +268,7 @@ class Database(object):
         # to watch it
         self._txn_perf_counters = PerformanceCounters()

-        self.engine = hs.database_engine
+        self.engine = engine

         # A set of tables that are not safe to use native upserts in.
         self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())

@@ -255,6 +289,11 @@ class Database(object):
             self._check_safe_to_upsert,
         )

+    def is_running(self):
+        """Is the database pool currently running
+        """
+        return self._db_pool.running
+
     @defer.inlineCallbacks
     def _check_safe_to_upsert(self):
         """


@@ -16,8 +16,6 @@
 import struct
 import threading

-from synapse.storage.prepare_database import prepare_database
-

 class Sqlite3Engine(object):
     single_threaded = True

@@ -25,6 +23,9 @@ class Sqlite3Engine(object):
     def __init__(self, database_module, database_config):
         self.module = database_module

+        database = database_config.get("args", {}).get("database")
+        self._is_in_memory = database in (None, ":memory:",)
+
         # The current max state_group, or None if we haven't looked
         # in the DB yet.
         self._current_state_group_id = None

@@ -59,7 +60,16 @@ class Sqlite3Engine(object):
         return sql

     def on_new_connection(self, db_conn):
-        prepare_database(db_conn, self, config=None)
+        # We need to import here to avoid an import loop.
+        from synapse.storage.prepare_database import prepare_database
+
+        if self._is_in_memory:
+            # In memory databases need to be rebuilt each time. Ideally we'd
+            # reuse the same connection as we do when starting up, but that
+            # would involve using adbapi before we have started the reactor.
+            prepare_database(db_conn, self, config=None)
+
         db_conn.create_function("rank", 1, _rank)

     def is_deadlock(self, error):
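To see why the _is_in_memory flag matters (editor's sketch, not in the diff): every new connection to an in-memory sqlite database starts blank, so on_new_connection has to rebuild the schema each time. The schema_version query is an assumption about Synapse's internal schema tables:

    import sqlite3

    from synapse.storage.engines import create_engine

    engine = create_engine({"name": "sqlite3", "args": {"database": ":memory:"}})
    assert engine._is_in_memory

    conn = sqlite3.connect(":memory:")  # a blank database on every connect
    engine.on_new_connection(conn)      # rebuilds the schema, registers rank()

    cur = conn.cursor()
    cur.execute("SELECT version FROM schema_version")
    print(cur.fetchone())               # the freshly-prepared schema version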


@@ -41,7 +41,7 @@ class UpgradeDatabaseException(PrepareDatabaseException):
     pass

-def prepare_database(db_conn, database_engine, config):
+def prepare_database(db_conn, database_engine, config, data_stores=["main"]):
     """Prepares a database for usage. Will either create all necessary tables
     or upgrade from an older schema version.

@@ -54,11 +54,10 @@ def prepare_database(db_conn, database_engine, config):
         config (synapse.config.homeserver.HomeServerConfig|None):
             application config, or None if we are connecting to an existing
             database which we expect to be configured already
+        data_stores (list[str]): The name of the data stores that will be used
+            with this database. Defaults to all data stores.
     """
-    # For now we only have the one datastore.
-    data_stores = ["main"]
-
     try:
         cur = db_conn.cursor()
         version_info = _get_or_create_schema_state(cur, database_engine)
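Sketch of the new parameter (editor's addition, not part of the commit): callers can now say which data stores' schemas should be provisioned on a given physical database, and config=None is still allowed for a brand-new database:

    import sqlite3

    from synapse.storage.engines import create_engine
    from synapse.storage.prepare_database import prepare_database

    engine = create_engine({"name": "sqlite3", "args": {}})
    conn = sqlite3.connect(":memory:")

    # Provision only the "main" data store's schema on this database.
    prepare_database(conn, engine, config=None, data_stores=["main"])
    conn.commit()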


@@ -64,9 +64,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         mock_federation_client = Mock(spec=["put_json"])
         mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))

-        hs = self.setup_test_homeserver(
-            datastore=(
-                Mock(
-                    spec=[
-                        # Bits that Federation needs
-                        "prep_send_transaction",
+        datastores = Mock()
+        datastores.main = Mock(
+            spec=[
+                # Bits that Federation needs
+                "prep_send_transaction",

@@ -74,18 +73,20 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
-                        "get_received_txn_response",
-                        "set_received_txn_response",
-                        "get_destination_retry_timings",
-                        "get_device_updates_by_remote",
-                        # Bits that user_directory needs
-                        "get_user_directory_stream_pos",
-                        "get_current_state_deltas",
-                    ]
-                )
-            ),
-            notifier=Mock(),
-            http_client=mock_federation_client,
-            keyring=mock_keyring,
-        )
+                "get_received_txn_response",
+                "set_received_txn_response",
+                "get_destination_retry_timings",
+                "get_devices_by_remote",
+                # Bits that user_directory needs
+                "get_user_directory_stream_pos",
+                "get_current_state_deltas",
+                "get_device_updates_by_remote",
+            ]
+        )
+
+        hs = self.setup_test_homeserver(
+            notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring
+        )
+        hs.datastores = datastores

         return hs

     def prepare(self, reactor, clock, hs):


@@ -20,7 +20,7 @@ from synapse.replication.tcp.client import (
     ReplicationClientHandler,
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.storage.database import Database
+from synapse.storage.database import make_conn

 from tests import unittest
 from tests.server import FakeTransport

@@ -41,10 +41,12 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):

+        db_config = hs.config.database.get_single_database()
         self.master_store = self.hs.get_datastore()
         self.storage = hs.get_storage()
+        database = hs.get_datastores().databases[0]
         self.slaved_store = self.STORE_TYPE(
-            Database(hs), self.hs.get_db_conn(), self.hs
+            database, make_conn(db_config, database.engine), self.hs
         )
         self.event_id = 0


@@ -302,14 +302,15 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
     Set up a synchronous test server, driven by the reactor used by
     the homeserver.
     """
-    d = _sth(cleanup_func, *args, **kwargs).result
+    server = _sth(cleanup_func, *args, **kwargs)

-    if isinstance(d, Failure):
-        d.raiseException()
+    database = server.config.database.get_single_database()

     # Make the thread pool synchronous.
-    clock = d.get_clock()
-    pool = d.get_db_pool()
+    clock = server.get_clock()
+
+    for database in server.get_datastores().databases:
+        pool = database._db_pool

     def runWithConnection(func, *args, **kwargs):
         return threads.deferToThreadPool(

@@ -331,12 +332,12 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
             **kwargs
         )

-    if pool:
         pool.runWithConnection = runWithConnection
         pool.runInteraction = runInteraction
         pool.threadpool = ThreadPool(clock._reactor)
         pool.running = True
-    return d
+
+    return server


 def get_clock():

@@ -28,7 +28,7 @@ from synapse.storage.data_stores.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_conn

 from tests import unittest
 from tests.utils import setup_test_homeserver

@@ -55,8 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")

         # must be done after inserts
-        database = Database(hs)
-        self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs)
+        database = hs.get_datastores().databases[0]
+        self.store = ApplicationServiceStore(
+            database, make_conn(database._database_config, database.engine), hs
+        )

     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!

@@ -111,9 +113,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         hs.config.event_cache_size = 1
         hs.config.password_providers = []

-        self.db_pool = hs.get_db_pool()
-        self.engine = hs.database_engine
-
         self.as_list = [
             {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
             {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"},

@@ -125,8 +124,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         self.as_yaml_files = []

-        database = Database(hs)
-        self.store = TestTransactionStore(database, hs.get_db_conn(), hs)
+        # We assume there is only one database in these tests
+        database = hs.get_datastores().databases[0]
+        self.db_pool = database._db_pool
+        self.engine = database.engine
+
+        db_config = hs.config.get_single_database()
+        self.store = TestTransactionStore(
+            database, make_conn(db_config, self.engine), hs
+        )

     def _add_service(self, url, as_token, id):
         as_yaml = dict(

@@ -419,7 +425,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.event_cache_size = 1
         hs.config.password_providers = []

-        ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+        database = hs.get_datastores().databases[0]
+        ApplicationServiceStore(
+            database, make_conn(database._database_config, database.engine), hs
+        )

     @defer.inlineCallbacks
     def test_duplicate_ids(self):

@@ -435,7 +444,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []

         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+            database = hs.get_datastores().databases[0]
+            ApplicationServiceStore(
+                database, make_conn(database._database_config, database.engine), hs
+            )

         e = cm.exception
         self.assertIn(f1, str(e))

@@ -456,7 +468,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []

         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+            database = hs.get_datastores().databases[0]
+            ApplicationServiceStore(
+                database, make_conn(database._database_config, database.engine), hs
+            )

         e = cm.exception
         self.assertIn(f1, str(e))


@@ -52,15 +52,17 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         config = Mock()
         config._disable_native_upserts = True
         config.event_cache_size = 1
-        config.database_config = {"name": "sqlite3"}
-        engine = create_engine(config.database_config)
+        hs = TestHomeServer("test", config=config)
+
+        sqlite_config = {"name": "sqlite3"}
+        engine = create_engine(sqlite_config)
         fake_engine = Mock(wraps=engine)
         fake_engine.can_native_upsert = False
-        hs = TestHomeServer(
-            "test", db_pool=self.db_pool, config=config, database_engine=fake_engine
-        )

-        self.datastore = SQLBaseStore(Database(hs), None, hs)
+        db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+        db._db_pool = self.db_pool
+
+        self.datastore = SQLBaseStore(db, None, hs)

     @defer.inlineCallbacks
     def test_insert_1col(self):


@@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
        hs = yield setup_test_homeserver(self.addCleanup)
-        self.db_pool = hs.get_db_pool()

         self.store = hs.get_datastore()


@@ -30,6 +30,7 @@ from twisted.internet import defer, reactor
 from synapse.api.constants import EventTypes
 from synapse.api.errors import CodeMessageException, cs_error
 from synapse.api.room_versions import RoomVersions
+from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
 from synapse.federation.transport import server as federation_server

@@ -177,7 +178,6 @@ class TestHomeServer(HomeServer):
     DATASTORE_CLASS = DataStore

-@defer.inlineCallbacks
 def setup_test_homeserver(
     cleanup_func,
     name="test",

@@ -214,7 +214,7 @@ def setup_test_homeserver(
     if USE_POSTGRES_FOR_TESTS:
         test_db = "synapse_test_%s" % uuid.uuid4().hex

-        config.database_config = {
+        database_config = {
             "name": "psycopg2",
             "args": {
                 "database": test_db,

@@ -226,12 +226,15 @@ def setup_test_homeserver(
             },
         }
     else:
-        config.database_config = {
+        database_config = {
             "name": "sqlite3",
             "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
         }

-    db_engine = create_engine(config.database_config)
+    database = DatabaseConnectionConfig("master", database_config, ["main"])
+
+    config.database.databases = [database]
+
+    db_engine = create_engine(database.config)

     # Create the database before we actually try and connect to it, based off
     # the template database we generate in setupdb()

@@ -251,11 +254,6 @@ def setup_test_homeserver(
         cur.close()
         db_conn.close()

-    # we need to configure the connection pool to run the on_new_connection
-    # function, so that we can test code that uses custom sqlite functions
-    # (like rank).
-    config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
-
     if datastore is None:
         hs = homeserverToUse(
             name,

@@ -267,21 +265,19 @@ def setup_test_homeserver(
             **kargs
         )

-        # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
-        # date db
-        if not isinstance(db_engine, PostgresEngine):
-            db_conn = hs.get_db_conn()
-            yield prepare_database(db_conn, db_engine, config)
-            db_conn.commit()
-            db_conn.close()
+        hs.setup()
+        if homeserverToUse.__name__ == "TestHomeServer":
+            hs.setup_master()

-        else:
+        if isinstance(db_engine, PostgresEngine):
+            database = hs.get_datastores().databases[0]
+
             # We need to do cleanup on PostgreSQL
             def cleanup():
                 import psycopg2

                 # Close all the db pools
-                hs.get_db_pool().close()
+                database._db_pool.close()

                 dropped = False

@@ -320,23 +316,12 @@ def setup_test_homeserver(
             # Register the cleanup hook
             cleanup_func(cleanup)

-        hs.setup()
-        if homeserverToUse.__name__ == "TestHomeServer":
-            hs.setup_master()
     else:
+        # If we have been given an explicit datastore we probably want to mock
+        # out the DataStores somehow too. This all feels a bit wrong, but then
+        # mocking the stores feels wrong too.
+        datastores = Mock(datastore=datastore)
+
         hs = homeserverToUse(
             name,
-            db_pool=None,
             datastore=datastore,
+            datastores=datastores,
             config=config,
             version_string="Synapse/tests",
-            database_engine=db_engine,
             tls_server_context_factory=Mock(),
             tls_client_options_factory=Mock(),
             reactor=reactor,