mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-25 02:55:46 +03:00
Add database config class (#6513)
This encapsulates config for a given database and is the way to get new connections.
This commit is contained in:
parent
91ccfe9f37
commit
2284eb3a53
19 changed files with 286 additions and 208 deletions
1
changelog.d/6513.misc
Normal file
1
changelog.d/6513.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Remove all assumptions of there being a single phyiscal DB apart from the `synapse.config`.
|
|
@ -26,7 +26,6 @@ from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
from synapse.storage.prepare_database import prepare_database
|
|
||||||
|
|
||||||
logger = logging.getLogger("update_database")
|
logger = logging.getLogger("update_database")
|
||||||
|
|
||||||
|
@ -77,12 +76,8 @@ if __name__ == "__main__":
|
||||||
# Instantiate and initialise the homeserver object.
|
# Instantiate and initialise the homeserver object.
|
||||||
hs = MockHomeserver(config)
|
hs = MockHomeserver(config)
|
||||||
|
|
||||||
db_conn = hs.get_db_conn()
|
# Setup instantiates the store within the homeserver object and updates the
|
||||||
# Update the database to the latest schema.
|
# DB.
|
||||||
prepare_database(db_conn, hs.database_engine, config=config)
|
|
||||||
db_conn.commit()
|
|
||||||
|
|
||||||
# setup instantiates the store within the homeserver object.
|
|
||||||
hs.setup()
|
hs.setup()
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,7 @@ import yaml
|
||||||
from twisted.enterprise import adbapi
|
from twisted.enterprise import adbapi
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
|
|
||||||
|
from synapse.config.database import DatabaseConnectionConfig
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.logging.context import PreserveLoggingContext
|
from synapse.logging.context import PreserveLoggingContext
|
||||||
from synapse.storage._base import LoggingTransaction
|
from synapse.storage._base import LoggingTransaction
|
||||||
|
@ -55,7 +56,7 @@ from synapse.storage.data_stores.main.stats import StatsStore
|
||||||
from synapse.storage.data_stores.main.user_directory import (
|
from synapse.storage.data_stores.main.user_directory import (
|
||||||
UserDirectoryBackgroundUpdateStore,
|
UserDirectoryBackgroundUpdateStore,
|
||||||
)
|
)
|
||||||
from synapse.storage.database import Database
|
from synapse.storage.database import Database, make_conn
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.storage.prepare_database import prepare_database
|
from synapse.storage.prepare_database import prepare_database
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
@ -165,23 +166,17 @@ class Store(
|
||||||
|
|
||||||
|
|
||||||
class MockHomeserver:
|
class MockHomeserver:
|
||||||
def __init__(self, config, database_engine, db_conn, db_pool):
|
def __init__(self, config):
|
||||||
self.database_engine = database_engine
|
|
||||||
self.db_conn = db_conn
|
|
||||||
self.db_pool = db_pool
|
|
||||||
self.clock = Clock(reactor)
|
self.clock = Clock(reactor)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.hostname = config.server_name
|
self.hostname = config.server_name
|
||||||
|
|
||||||
def get_db_conn(self):
|
|
||||||
return self.db_conn
|
|
||||||
|
|
||||||
def get_db_pool(self):
|
|
||||||
return self.db_pool
|
|
||||||
|
|
||||||
def get_clock(self):
|
def get_clock(self):
|
||||||
return self.clock
|
return self.clock
|
||||||
|
|
||||||
|
def get_reactor(self):
|
||||||
|
return reactor
|
||||||
|
|
||||||
|
|
||||||
class Porter(object):
|
class Porter(object):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
|
@ -445,45 +440,36 @@ class Porter(object):
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
def setup_db(self, db_config, database_engine):
|
def setup_db(self, db_config: DatabaseConnectionConfig, engine):
|
||||||
db_conn = database_engine.module.connect(
|
db_conn = make_conn(db_config, engine)
|
||||||
**{
|
prepare_database(db_conn, engine, config=None)
|
||||||
k: v
|
|
||||||
for k, v in db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
prepare_database(db_conn, database_engine, config=None)
|
|
||||||
|
|
||||||
db_conn.commit()
|
db_conn.commit()
|
||||||
|
|
||||||
return db_conn
|
return db_conn
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def build_db_store(self, config):
|
def build_db_store(self, db_config: DatabaseConnectionConfig):
|
||||||
"""Builds and returns a database store using the provided configuration.
|
"""Builds and returns a database store using the provided configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: The database configuration, i.e. a dict following the structure of
|
config: The database configuration
|
||||||
the "database" section of Synapse's configuration file.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The built Store object.
|
The built Store object.
|
||||||
"""
|
"""
|
||||||
engine = create_engine(config)
|
self.progress.set_state("Preparing %s" % db_config.config["name"])
|
||||||
|
|
||||||
self.progress.set_state("Preparing %s" % config["name"])
|
engine = create_engine(db_config.config)
|
||||||
conn = self.setup_db(config, engine)
|
conn = self.setup_db(db_config, engine)
|
||||||
|
|
||||||
db_pool = adbapi.ConnectionPool(config["name"], **config["args"])
|
hs = MockHomeserver(self.hs_config)
|
||||||
|
|
||||||
hs = MockHomeserver(self.hs_config, engine, conn, db_pool)
|
store = Store(Database(hs, db_config, engine), conn, hs)
|
||||||
|
|
||||||
store = Store(Database(hs), conn, hs)
|
|
||||||
|
|
||||||
yield store.db.runInteraction(
|
yield store.db.runInteraction(
|
||||||
"%s_engine.check_database" % config["name"], engine.check_database,
|
"%s_engine.check_database" % db_config.config["name"],
|
||||||
|
engine.check_database,
|
||||||
)
|
)
|
||||||
|
|
||||||
return store
|
return store
|
||||||
|
@ -509,7 +495,11 @@ class Porter(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def run(self):
|
def run(self):
|
||||||
try:
|
try:
|
||||||
self.sqlite_store = yield self.build_db_store(self.sqlite_config)
|
self.sqlite_store = yield self.build_db_store(
|
||||||
|
DatabaseConnectionConfig(
|
||||||
|
"master", self.sqlite_config, data_stores=["main"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Check if all background updates are done, abort if not.
|
# Check if all background updates are done, abort if not.
|
||||||
updates_complete = (
|
updates_complete = (
|
||||||
|
@ -524,7 +514,7 @@ class Porter(object):
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
self.postgres_store = yield self.build_db_store(
|
self.postgres_store = yield self.build_db_store(
|
||||||
self.hs_config.database_config
|
self.hs_config.get_single_database()
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.run_background_updates_on_postgres()
|
yield self.run_background_updates_on_postgres()
|
||||||
|
|
|
@ -12,12 +12,43 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from textwrap import indent
|
from textwrap import indent
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from ._base import Config
|
from synapse.config._base import Config, ConfigError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseConnectionConfig:
|
||||||
|
"""Contains the connection config for a particular database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: A label for the database, used for logging.
|
||||||
|
db_config: The config for a particular database, as per `database`
|
||||||
|
section of main config. Has two fields: `name` for database
|
||||||
|
module name, and `args` for the args to give to the database
|
||||||
|
connector.
|
||||||
|
data_stores: The list of data stores that should be provisioned on the
|
||||||
|
database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name: str, db_config: dict, data_stores: List[str]):
|
||||||
|
if db_config["name"] not in ("sqlite3", "psycopg2"):
|
||||||
|
raise ConfigError("Unsupported database type %r" % (db_config["name"],))
|
||||||
|
|
||||||
|
if db_config["name"] == "sqlite3":
|
||||||
|
db_config.setdefault("args", {}).update(
|
||||||
|
{"cp_min": 1, "cp_max": 1, "check_same_thread": False}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.name = name
|
||||||
|
self.config = db_config
|
||||||
|
self.data_stores = data_stores
|
||||||
|
|
||||||
|
|
||||||
class DatabaseConfig(Config):
|
class DatabaseConfig(Config):
|
||||||
|
@ -26,20 +57,14 @@ class DatabaseConfig(Config):
|
||||||
def read_config(self, config, **kwargs):
|
def read_config(self, config, **kwargs):
|
||||||
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
|
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
|
||||||
|
|
||||||
self.database_config = config.get("database")
|
database_config = config.get("database")
|
||||||
|
|
||||||
if self.database_config is None:
|
if database_config is None:
|
||||||
self.database_config = {"name": "sqlite3", "args": {}}
|
database_config = {"name": "sqlite3", "args": {}}
|
||||||
|
|
||||||
name = self.database_config.get("name", None)
|
self.databases = [
|
||||||
if name == "psycopg2":
|
DatabaseConnectionConfig("master", database_config, data_stores=["main"])
|
||||||
pass
|
]
|
||||||
elif name == "sqlite3":
|
|
||||||
self.database_config.setdefault("args", {}).update(
|
|
||||||
{"cp_min": 1, "cp_max": 1, "check_same_thread": False}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Unsupported database type '%s'" % (name,))
|
|
||||||
|
|
||||||
self.set_databasepath(config.get("database_path"))
|
self.set_databasepath(config.get("database_path"))
|
||||||
|
|
||||||
|
@ -76,11 +101,24 @@ class DatabaseConfig(Config):
|
||||||
self.set_databasepath(args.database_path)
|
self.set_databasepath(args.database_path)
|
||||||
|
|
||||||
def set_databasepath(self, database_path):
|
def set_databasepath(self, database_path):
|
||||||
|
if database_path is None:
|
||||||
|
return
|
||||||
|
|
||||||
if database_path != ":memory:":
|
if database_path != ":memory:":
|
||||||
database_path = self.abspath(database_path)
|
database_path = self.abspath(database_path)
|
||||||
if self.database_config.get("name", None) == "sqlite3":
|
|
||||||
if database_path is not None:
|
# We only support setting a database path if we have a single sqlite3
|
||||||
self.database_config["args"]["database"] = database_path
|
# database.
|
||||||
|
if len(self.databases) != 1:
|
||||||
|
raise ConfigError("Cannot specify 'database_path' with multiple databases")
|
||||||
|
|
||||||
|
database = self.get_single_database()
|
||||||
|
if database.config["name"] != "sqlite3":
|
||||||
|
# We don't raise here as we haven't done so before for this case.
|
||||||
|
logger.warn("Ignoring 'database_path' for non-sqlite3 database")
|
||||||
|
return
|
||||||
|
|
||||||
|
database.config["args"]["database"] = database_path
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_arguments(parser):
|
def add_arguments(parser):
|
||||||
|
@ -91,3 +129,11 @@ class DatabaseConfig(Config):
|
||||||
metavar="SQLITE_DATABASE_PATH",
|
metavar="SQLITE_DATABASE_PATH",
|
||||||
help="The path to a sqlite database to use.",
|
help="The path to a sqlite database to use.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_single_database(self) -> DatabaseConnectionConfig:
|
||||||
|
"""Returns the database if there is only one, useful for e.g. tests
|
||||||
|
"""
|
||||||
|
if len(self.databases) != 1:
|
||||||
|
raise Exception("More than one database exists")
|
||||||
|
|
||||||
|
return self.databases[0]
|
||||||
|
|
|
@ -230,7 +230,7 @@ class PresenceHandler(object):
|
||||||
is some spurious presence changes that will self-correct.
|
is some spurious presence changes that will self-correct.
|
||||||
"""
|
"""
|
||||||
# If the DB pool has already terminated, don't try updating
|
# If the DB pool has already terminated, don't try updating
|
||||||
if not self.hs.get_db_pool().running:
|
if not self.store.database.is_running():
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
@ -25,7 +25,6 @@ import abc
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from twisted.enterprise import adbapi
|
|
||||||
from twisted.mail.smtp import sendmail
|
from twisted.mail.smtp import sendmail
|
||||||
from twisted.web.client import BrowserLikePolicyForHTTPS
|
from twisted.web.client import BrowserLikePolicyForHTTPS
|
||||||
|
|
||||||
|
@ -98,7 +97,6 @@ from synapse.server_notices.worker_server_notices_sender import (
|
||||||
)
|
)
|
||||||
from synapse.state import StateHandler, StateResolutionHandler
|
from synapse.state import StateHandler, StateResolutionHandler
|
||||||
from synapse.storage import DataStores, Storage
|
from synapse.storage import DataStores, Storage
|
||||||
from synapse.storage.engines import create_engine
|
|
||||||
from synapse.streams.events import EventSources
|
from synapse.streams.events import EventSources
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
from synapse.util.distributor import Distributor
|
from synapse.util.distributor import Distributor
|
||||||
|
@ -134,7 +132,6 @@ class HomeServer(object):
|
||||||
|
|
||||||
DEPENDENCIES = [
|
DEPENDENCIES = [
|
||||||
"http_client",
|
"http_client",
|
||||||
"db_pool",
|
|
||||||
"federation_client",
|
"federation_client",
|
||||||
"federation_server",
|
"federation_server",
|
||||||
"handlers",
|
"handlers",
|
||||||
|
@ -233,12 +230,6 @@ class HomeServer(object):
|
||||||
self.admin_redaction_ratelimiter = Ratelimiter()
|
self.admin_redaction_ratelimiter = Ratelimiter()
|
||||||
self.registration_ratelimiter = Ratelimiter()
|
self.registration_ratelimiter = Ratelimiter()
|
||||||
|
|
||||||
self.database_engine = create_engine(config.database_config)
|
|
||||||
config.database_config.setdefault("args", {})[
|
|
||||||
"cp_openfun"
|
|
||||||
] = self.database_engine.on_new_connection
|
|
||||||
self.db_config = config.database_config
|
|
||||||
|
|
||||||
self.datastores = None
|
self.datastores = None
|
||||||
|
|
||||||
# Other kwargs are explicit dependencies
|
# Other kwargs are explicit dependencies
|
||||||
|
@ -247,10 +238,8 @@ class HomeServer(object):
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
with self.get_db_conn() as conn:
|
|
||||||
self.datastores = DataStores(self.DATASTORE_CLASS, conn, self)
|
|
||||||
conn.commit()
|
|
||||||
self.start_time = int(self.get_clock().time())
|
self.start_time = int(self.get_clock().time())
|
||||||
|
self.datastores = DataStores(self.DATASTORE_CLASS, self)
|
||||||
logger.info("Finished setting up.")
|
logger.info("Finished setting up.")
|
||||||
|
|
||||||
def setup_master(self):
|
def setup_master(self):
|
||||||
|
@ -284,6 +273,9 @@ class HomeServer(object):
|
||||||
def get_datastore(self):
|
def get_datastore(self):
|
||||||
return self.datastores.main
|
return self.datastores.main
|
||||||
|
|
||||||
|
def get_datastores(self):
|
||||||
|
return self.datastores
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return self.config
|
return self.config
|
||||||
|
|
||||||
|
@ -433,31 +425,6 @@ class HomeServer(object):
|
||||||
)
|
)
|
||||||
return MatrixFederationHttpClient(self, tls_client_options_factory)
|
return MatrixFederationHttpClient(self, tls_client_options_factory)
|
||||||
|
|
||||||
def build_db_pool(self):
|
|
||||||
name = self.db_config["name"]
|
|
||||||
|
|
||||||
return adbapi.ConnectionPool(
|
|
||||||
name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {})
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
"""Makes a new connection to the database, skipping the db pool
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Connection: a connection object implementing the PEP-249 spec
|
|
||||||
"""
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v
|
|
||||||
for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def build_media_repository_resource(self):
|
def build_media_repository_resource(self):
|
||||||
# build the media repo resource. This indirects through the HomeServer
|
# build the media repo resource. This indirects through the HomeServer
|
||||||
# to ensure that we only have a single instance of
|
# to ensure that we only have a single instance of
|
||||||
|
|
|
@ -40,7 +40,7 @@ class SQLBaseStore(object):
|
||||||
def __init__(self, database: Database, db_conn, hs):
|
def __init__(self, database: Database, db_conn, hs):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self.database_engine = hs.database_engine
|
self.database_engine = database.engine
|
||||||
self.db = database
|
self.db = database
|
||||||
self.rand = random.SystemRandom()
|
self.rand = random.SystemRandom()
|
||||||
|
|
||||||
|
|
|
@ -13,24 +13,50 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.storage.database import Database
|
import logging
|
||||||
|
|
||||||
|
from synapse.storage.database import Database, make_conn
|
||||||
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.storage.prepare_database import prepare_database
|
from synapse.storage.prepare_database import prepare_database
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DataStores(object):
|
class DataStores(object):
|
||||||
"""The various data stores.
|
"""The various data stores.
|
||||||
|
|
||||||
These are low level interfaces to physical databases.
|
These are low level interfaces to physical databases.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
main (DataStore)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, main_store_class, db_conn, hs):
|
def __init__(self, main_store_class, hs):
|
||||||
# Note we pass in the main store class here as workers use a different main
|
# Note we pass in the main store class here as workers use a different main
|
||||||
# store.
|
# store.
|
||||||
database = Database(hs)
|
|
||||||
|
|
||||||
# Check that db is correctly configured.
|
self.databases = []
|
||||||
database.engine.check_database(db_conn.cursor())
|
|
||||||
|
|
||||||
prepare_database(db_conn, database.engine, config=hs.config)
|
for database_config in hs.config.database.databases:
|
||||||
|
db_name = database_config.name
|
||||||
|
engine = create_engine(database_config.config)
|
||||||
|
|
||||||
|
with make_conn(database_config, engine) as db_conn:
|
||||||
|
logger.info("Preparing database %r...", db_name)
|
||||||
|
|
||||||
|
engine.check_database(db_conn.cursor())
|
||||||
|
prepare_database(
|
||||||
|
db_conn, engine, hs.config, data_stores=database_config.data_stores,
|
||||||
|
)
|
||||||
|
|
||||||
|
database = Database(hs, database_config, engine)
|
||||||
|
|
||||||
|
if "main" in database_config.data_stores:
|
||||||
|
logger.info("Starting 'main' data store")
|
||||||
self.main = main_store_class(database, db_conn, hs)
|
self.main = main_store_class(database, db_conn, hs)
|
||||||
|
|
||||||
|
db_conn.commit()
|
||||||
|
|
||||||
|
self.databases.append(database)
|
||||||
|
|
||||||
|
logger.info("Database %r prepared", db_name)
|
||||||
|
|
|
@ -412,7 +412,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
|
||||||
def _update_client_ips_batch(self):
|
def _update_client_ips_batch(self):
|
||||||
|
|
||||||
# If the DB pool has already terminated, don't try updating
|
# If the DB pool has already terminated, don't try updating
|
||||||
if not self.hs.get_db_pool().running:
|
if not self.db.is_running():
|
||||||
return
|
return
|
||||||
|
|
||||||
to_update = self._batch_row_update
|
to_update = self._batch_row_update
|
||||||
|
|
|
@ -24,9 +24,11 @@ from six.moves import intern, range
|
||||||
|
|
||||||
from prometheus_client import Histogram
|
from prometheus_client import Histogram
|
||||||
|
|
||||||
|
from twisted.enterprise import adbapi
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
|
from synapse.config.database import DatabaseConnectionConfig
|
||||||
from synapse.logging.context import LoggingContext, make_deferred_yieldable
|
from synapse.logging.context import LoggingContext, make_deferred_yieldable
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage.background_updates import BackgroundUpdater
|
from synapse.storage.background_updates import BackgroundUpdater
|
||||||
|
@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def make_pool(
|
||||||
|
reactor, db_config: DatabaseConnectionConfig, engine
|
||||||
|
) -> adbapi.ConnectionPool:
|
||||||
|
"""Get the connection pool for the database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return adbapi.ConnectionPool(
|
||||||
|
db_config.config["name"],
|
||||||
|
cp_reactor=reactor,
|
||||||
|
cp_openfun=engine.on_new_connection,
|
||||||
|
**db_config.config.get("args", {})
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_conn(db_config: DatabaseConnectionConfig, engine):
|
||||||
|
"""Make a new connection to the database and return it.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Connection
|
||||||
|
"""
|
||||||
|
|
||||||
|
db_params = {
|
||||||
|
k: v
|
||||||
|
for k, v in db_config.config.get("args", {}).items()
|
||||||
|
if not k.startswith("cp_")
|
||||||
|
}
|
||||||
|
db_conn = engine.module.connect(**db_params)
|
||||||
|
engine.on_new_connection(db_conn)
|
||||||
|
return db_conn
|
||||||
|
|
||||||
|
|
||||||
class LoggingTransaction(object):
|
class LoggingTransaction(object):
|
||||||
"""An object that almost-transparently proxies for the 'txn' object
|
"""An object that almost-transparently proxies for the 'txn' object
|
||||||
passed to the constructor. Adds logging and metrics to the .execute()
|
passed to the constructor. Adds logging and metrics to the .execute()
|
||||||
|
@ -218,10 +251,11 @@ class Database(object):
|
||||||
|
|
||||||
_TXN_ID = 0
|
_TXN_ID = 0
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self._db_pool = hs.get_db_pool()
|
self._database_config = database_config
|
||||||
|
self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
|
||||||
|
|
||||||
self.updates = BackgroundUpdater(hs, self)
|
self.updates = BackgroundUpdater(hs, self)
|
||||||
|
|
||||||
|
@ -234,7 +268,7 @@ class Database(object):
|
||||||
# to watch it
|
# to watch it
|
||||||
self._txn_perf_counters = PerformanceCounters()
|
self._txn_perf_counters = PerformanceCounters()
|
||||||
|
|
||||||
self.engine = hs.database_engine
|
self.engine = engine
|
||||||
|
|
||||||
# A set of tables that are not safe to use native upserts in.
|
# A set of tables that are not safe to use native upserts in.
|
||||||
self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
|
self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
|
||||||
|
@ -255,6 +289,11 @@ class Database(object):
|
||||||
self._check_safe_to_upsert,
|
self._check_safe_to_upsert,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def is_running(self):
|
||||||
|
"""Is the database pool currently running
|
||||||
|
"""
|
||||||
|
return self._db_pool.running
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_safe_to_upsert(self):
|
def _check_safe_to_upsert(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -16,8 +16,6 @@
|
||||||
import struct
|
import struct
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from synapse.storage.prepare_database import prepare_database
|
|
||||||
|
|
||||||
|
|
||||||
class Sqlite3Engine(object):
|
class Sqlite3Engine(object):
|
||||||
single_threaded = True
|
single_threaded = True
|
||||||
|
@ -25,6 +23,9 @@ class Sqlite3Engine(object):
|
||||||
def __init__(self, database_module, database_config):
|
def __init__(self, database_module, database_config):
|
||||||
self.module = database_module
|
self.module = database_module
|
||||||
|
|
||||||
|
database = database_config.get("args", {}).get("database")
|
||||||
|
self._is_in_memory = database in (None, ":memory:",)
|
||||||
|
|
||||||
# The current max state_group, or None if we haven't looked
|
# The current max state_group, or None if we haven't looked
|
||||||
# in the DB yet.
|
# in the DB yet.
|
||||||
self._current_state_group_id = None
|
self._current_state_group_id = None
|
||||||
|
@ -59,7 +60,16 @@ class Sqlite3Engine(object):
|
||||||
return sql
|
return sql
|
||||||
|
|
||||||
def on_new_connection(self, db_conn):
|
def on_new_connection(self, db_conn):
|
||||||
|
|
||||||
|
# We need to import here to avoid an import loop.
|
||||||
|
from synapse.storage.prepare_database import prepare_database
|
||||||
|
|
||||||
|
if self._is_in_memory:
|
||||||
|
# In memory databases need to be rebuilt each time. Ideally we'd
|
||||||
|
# reuse the same connection as we do when starting up, but that
|
||||||
|
# would involve using adbapi before we have started the reactor.
|
||||||
prepare_database(db_conn, self, config=None)
|
prepare_database(db_conn, self, config=None)
|
||||||
|
|
||||||
db_conn.create_function("rank", 1, _rank)
|
db_conn.create_function("rank", 1, _rank)
|
||||||
|
|
||||||
def is_deadlock(self, error):
|
def is_deadlock(self, error):
|
||||||
|
|
|
@ -41,7 +41,7 @@ class UpgradeDatabaseException(PrepareDatabaseException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def prepare_database(db_conn, database_engine, config):
|
def prepare_database(db_conn, database_engine, config, data_stores=["main"]):
|
||||||
"""Prepares a database for usage. Will either create all necessary tables
|
"""Prepares a database for usage. Will either create all necessary tables
|
||||||
or upgrade from an older schema version.
|
or upgrade from an older schema version.
|
||||||
|
|
||||||
|
@ -54,11 +54,10 @@ def prepare_database(db_conn, database_engine, config):
|
||||||
config (synapse.config.homeserver.HomeServerConfig|None):
|
config (synapse.config.homeserver.HomeServerConfig|None):
|
||||||
application config, or None if we are connecting to an existing
|
application config, or None if we are connecting to an existing
|
||||||
database which we expect to be configured already
|
database which we expect to be configured already
|
||||||
|
data_stores (list[str]): The name of the data stores that will be used
|
||||||
|
with this database. Defaults to all data stores.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# For now we only have the one datastore.
|
|
||||||
data_stores = ["main"]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cur = db_conn.cursor()
|
cur = db_conn.cursor()
|
||||||
version_info = _get_or_create_schema_state(cur, database_engine)
|
version_info = _get_or_create_schema_state(cur, database_engine)
|
||||||
|
|
|
@ -64,9 +64,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
mock_federation_client = Mock(spec=["put_json"])
|
mock_federation_client = Mock(spec=["put_json"])
|
||||||
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
|
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(
|
datastores = Mock()
|
||||||
datastore=(
|
datastores.main = Mock(
|
||||||
Mock(
|
|
||||||
spec=[
|
spec=[
|
||||||
# Bits that Federation needs
|
# Bits that Federation needs
|
||||||
"prep_send_transaction",
|
"prep_send_transaction",
|
||||||
|
@ -74,18 +73,20 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
"get_received_txn_response",
|
"get_received_txn_response",
|
||||||
"set_received_txn_response",
|
"set_received_txn_response",
|
||||||
"get_destination_retry_timings",
|
"get_destination_retry_timings",
|
||||||
"get_device_updates_by_remote",
|
"get_devices_by_remote",
|
||||||
# Bits that user_directory needs
|
# Bits that user_directory needs
|
||||||
"get_user_directory_stream_pos",
|
"get_user_directory_stream_pos",
|
||||||
"get_current_state_deltas",
|
"get_current_state_deltas",
|
||||||
|
"get_device_updates_by_remote",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
),
|
|
||||||
notifier=Mock(),
|
hs = self.setup_test_homeserver(
|
||||||
http_client=mock_federation_client,
|
notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring
|
||||||
keyring=mock_keyring,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
hs.datastores = datastores
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
|
|
|
@ -20,7 +20,7 @@ from synapse.replication.tcp.client import (
|
||||||
ReplicationClientHandler,
|
ReplicationClientHandler,
|
||||||
)
|
)
|
||||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||||
from synapse.storage.database import Database
|
from synapse.storage.database import make_conn
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import FakeTransport
|
from tests.server import FakeTransport
|
||||||
|
@ -41,10 +41,12 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
|
|
||||||
|
db_config = hs.config.database.get_single_database()
|
||||||
self.master_store = self.hs.get_datastore()
|
self.master_store = self.hs.get_datastore()
|
||||||
self.storage = hs.get_storage()
|
self.storage = hs.get_storage()
|
||||||
|
database = hs.get_datastores().databases[0]
|
||||||
self.slaved_store = self.STORE_TYPE(
|
self.slaved_store = self.STORE_TYPE(
|
||||||
Database(hs), self.hs.get_db_conn(), self.hs
|
database, make_conn(db_config, database.engine), self.hs
|
||||||
)
|
)
|
||||||
self.event_id = 0
|
self.event_id = 0
|
||||||
|
|
||||||
|
|
|
@ -302,14 +302,15 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
|
||||||
Set up a synchronous test server, driven by the reactor used by
|
Set up a synchronous test server, driven by the reactor used by
|
||||||
the homeserver.
|
the homeserver.
|
||||||
"""
|
"""
|
||||||
d = _sth(cleanup_func, *args, **kwargs).result
|
server = _sth(cleanup_func, *args, **kwargs)
|
||||||
|
|
||||||
if isinstance(d, Failure):
|
database = server.config.database.get_single_database()
|
||||||
d.raiseException()
|
|
||||||
|
|
||||||
# Make the thread pool synchronous.
|
# Make the thread pool synchronous.
|
||||||
clock = d.get_clock()
|
clock = server.get_clock()
|
||||||
pool = d.get_db_pool()
|
|
||||||
|
for database in server.get_datastores().databases:
|
||||||
|
pool = database._db_pool
|
||||||
|
|
||||||
def runWithConnection(func, *args, **kwargs):
|
def runWithConnection(func, *args, **kwargs):
|
||||||
return threads.deferToThreadPool(
|
return threads.deferToThreadPool(
|
||||||
|
@ -331,12 +332,12 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
if pool:
|
|
||||||
pool.runWithConnection = runWithConnection
|
pool.runWithConnection = runWithConnection
|
||||||
pool.runInteraction = runInteraction
|
pool.runInteraction = runInteraction
|
||||||
pool.threadpool = ThreadPool(clock._reactor)
|
pool.threadpool = ThreadPool(clock._reactor)
|
||||||
pool.running = True
|
pool.running = True
|
||||||
return d
|
|
||||||
|
return server
|
||||||
|
|
||||||
|
|
||||||
def get_clock():
|
def get_clock():
|
||||||
|
|
|
@ -28,7 +28,7 @@ from synapse.storage.data_stores.main.appservice import (
|
||||||
ApplicationServiceStore,
|
ApplicationServiceStore,
|
||||||
ApplicationServiceTransactionStore,
|
ApplicationServiceTransactionStore,
|
||||||
)
|
)
|
||||||
from synapse.storage.database import Database
|
from synapse.storage.database import Database, make_conn
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.utils import setup_test_homeserver
|
from tests.utils import setup_test_homeserver
|
||||||
|
@ -55,8 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
|
||||||
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
|
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
|
||||||
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
|
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
|
||||||
# must be done after inserts
|
# must be done after inserts
|
||||||
database = Database(hs)
|
database = hs.get_datastores().databases[0]
|
||||||
self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs)
|
self.store = ApplicationServiceStore(
|
||||||
|
database, make_conn(database._database_config, database.engine), hs
|
||||||
|
)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
# TODO: suboptimal that we need to create files for tests!
|
# TODO: suboptimal that we need to create files for tests!
|
||||||
|
@ -111,9 +113,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
||||||
hs.config.event_cache_size = 1
|
hs.config.event_cache_size = 1
|
||||||
hs.config.password_providers = []
|
hs.config.password_providers = []
|
||||||
|
|
||||||
self.db_pool = hs.get_db_pool()
|
|
||||||
self.engine = hs.database_engine
|
|
||||||
|
|
||||||
self.as_list = [
|
self.as_list = [
|
||||||
{"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
|
{"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
|
||||||
{"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"},
|
{"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"},
|
||||||
|
@ -125,8 +124,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.as_yaml_files = []
|
self.as_yaml_files = []
|
||||||
|
|
||||||
database = Database(hs)
|
# We assume there is only one database in these tests
|
||||||
self.store = TestTransactionStore(database, hs.get_db_conn(), hs)
|
database = hs.get_datastores().databases[0]
|
||||||
|
self.db_pool = database._db_pool
|
||||||
|
self.engine = database.engine
|
||||||
|
|
||||||
|
db_config = hs.config.get_single_database()
|
||||||
|
self.store = TestTransactionStore(
|
||||||
|
database, make_conn(db_config, self.engine), hs
|
||||||
|
)
|
||||||
|
|
||||||
def _add_service(self, url, as_token, id):
|
def _add_service(self, url, as_token, id):
|
||||||
as_yaml = dict(
|
as_yaml = dict(
|
||||||
|
@ -419,7 +425,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
||||||
hs.config.event_cache_size = 1
|
hs.config.event_cache_size = 1
|
||||||
hs.config.password_providers = []
|
hs.config.password_providers = []
|
||||||
|
|
||||||
ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
|
database = hs.get_datastores().databases[0]
|
||||||
|
ApplicationServiceStore(
|
||||||
|
database, make_conn(database._database_config, database.engine), hs
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_duplicate_ids(self):
|
def test_duplicate_ids(self):
|
||||||
|
@ -435,7 +444,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
||||||
hs.config.password_providers = []
|
hs.config.password_providers = []
|
||||||
|
|
||||||
with self.assertRaises(ConfigError) as cm:
|
with self.assertRaises(ConfigError) as cm:
|
||||||
ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
|
database = hs.get_datastores().databases[0]
|
||||||
|
ApplicationServiceStore(
|
||||||
|
database, make_conn(database._database_config, database.engine), hs
|
||||||
|
)
|
||||||
|
|
||||||
e = cm.exception
|
e = cm.exception
|
||||||
self.assertIn(f1, str(e))
|
self.assertIn(f1, str(e))
|
||||||
|
@ -456,7 +468,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
||||||
hs.config.password_providers = []
|
hs.config.password_providers = []
|
||||||
|
|
||||||
with self.assertRaises(ConfigError) as cm:
|
with self.assertRaises(ConfigError) as cm:
|
||||||
ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
|
database = hs.get_datastores().databases[0]
|
||||||
|
ApplicationServiceStore(
|
||||||
|
database, make_conn(database._database_config, database.engine), hs
|
||||||
|
)
|
||||||
|
|
||||||
e = cm.exception
|
e = cm.exception
|
||||||
self.assertIn(f1, str(e))
|
self.assertIn(f1, str(e))
|
||||||
|
|
|
@ -52,15 +52,17 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
config = Mock()
|
config = Mock()
|
||||||
config._disable_native_upserts = True
|
config._disable_native_upserts = True
|
||||||
config.event_cache_size = 1
|
config.event_cache_size = 1
|
||||||
config.database_config = {"name": "sqlite3"}
|
hs = TestHomeServer("test", config=config)
|
||||||
engine = create_engine(config.database_config)
|
|
||||||
|
sqlite_config = {"name": "sqlite3"}
|
||||||
|
engine = create_engine(sqlite_config)
|
||||||
fake_engine = Mock(wraps=engine)
|
fake_engine = Mock(wraps=engine)
|
||||||
fake_engine.can_native_upsert = False
|
fake_engine.can_native_upsert = False
|
||||||
hs = TestHomeServer(
|
|
||||||
"test", db_pool=self.db_pool, config=config, database_engine=fake_engine
|
|
||||||
)
|
|
||||||
|
|
||||||
self.datastore = SQLBaseStore(Database(hs), None, hs)
|
db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
|
||||||
|
db._db_pool = self.db_pool
|
||||||
|
|
||||||
|
self.datastore = SQLBaseStore(db, None, hs)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_insert_1col(self):
|
def test_insert_1col(self):
|
||||||
|
|
|
@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
hs = yield setup_test_homeserver(self.addCleanup)
|
hs = yield setup_test_homeserver(self.addCleanup)
|
||||||
self.db_pool = hs.get_db_pool()
|
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,7 @@ from twisted.internet import defer, reactor
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import CodeMessageException, cs_error
|
from synapse.api.errors import CodeMessageException, cs_error
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersions
|
||||||
|
from synapse.config.database import DatabaseConnectionConfig
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.config.server import DEFAULT_ROOM_VERSION
|
from synapse.config.server import DEFAULT_ROOM_VERSION
|
||||||
from synapse.federation.transport import server as federation_server
|
from synapse.federation.transport import server as federation_server
|
||||||
|
@ -177,7 +178,6 @@ class TestHomeServer(HomeServer):
|
||||||
DATASTORE_CLASS = DataStore
|
DATASTORE_CLASS = DataStore
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def setup_test_homeserver(
|
def setup_test_homeserver(
|
||||||
cleanup_func,
|
cleanup_func,
|
||||||
name="test",
|
name="test",
|
||||||
|
@ -214,7 +214,7 @@ def setup_test_homeserver(
|
||||||
if USE_POSTGRES_FOR_TESTS:
|
if USE_POSTGRES_FOR_TESTS:
|
||||||
test_db = "synapse_test_%s" % uuid.uuid4().hex
|
test_db = "synapse_test_%s" % uuid.uuid4().hex
|
||||||
|
|
||||||
config.database_config = {
|
database_config = {
|
||||||
"name": "psycopg2",
|
"name": "psycopg2",
|
||||||
"args": {
|
"args": {
|
||||||
"database": test_db,
|
"database": test_db,
|
||||||
|
@ -226,12 +226,15 @@ def setup_test_homeserver(
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
config.database_config = {
|
database_config = {
|
||||||
"name": "sqlite3",
|
"name": "sqlite3",
|
||||||
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
|
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
|
||||||
}
|
}
|
||||||
|
|
||||||
db_engine = create_engine(config.database_config)
|
database = DatabaseConnectionConfig("master", database_config, ["main"])
|
||||||
|
config.database.databases = [database]
|
||||||
|
|
||||||
|
db_engine = create_engine(database.config)
|
||||||
|
|
||||||
# Create the database before we actually try and connect to it, based off
|
# Create the database before we actually try and connect to it, based off
|
||||||
# the template database we generate in setupdb()
|
# the template database we generate in setupdb()
|
||||||
|
@ -251,11 +254,6 @@ def setup_test_homeserver(
|
||||||
cur.close()
|
cur.close()
|
||||||
db_conn.close()
|
db_conn.close()
|
||||||
|
|
||||||
# we need to configure the connection pool to run the on_new_connection
|
|
||||||
# function, so that we can test code that uses custom sqlite functions
|
|
||||||
# (like rank).
|
|
||||||
config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
|
|
||||||
|
|
||||||
if datastore is None:
|
if datastore is None:
|
||||||
hs = homeserverToUse(
|
hs = homeserverToUse(
|
||||||
name,
|
name,
|
||||||
|
@ -267,21 +265,19 @@ def setup_test_homeserver(
|
||||||
**kargs
|
**kargs
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
|
hs.setup()
|
||||||
# date db
|
if homeserverToUse.__name__ == "TestHomeServer":
|
||||||
if not isinstance(db_engine, PostgresEngine):
|
hs.setup_master()
|
||||||
db_conn = hs.get_db_conn()
|
|
||||||
yield prepare_database(db_conn, db_engine, config)
|
if isinstance(db_engine, PostgresEngine):
|
||||||
db_conn.commit()
|
database = hs.get_datastores().databases[0]
|
||||||
db_conn.close()
|
|
||||||
|
|
||||||
else:
|
|
||||||
# We need to do cleanup on PostgreSQL
|
# We need to do cleanup on PostgreSQL
|
||||||
def cleanup():
|
def cleanup():
|
||||||
import psycopg2
|
import psycopg2
|
||||||
|
|
||||||
# Close all the db pools
|
# Close all the db pools
|
||||||
hs.get_db_pool().close()
|
database._db_pool.close()
|
||||||
|
|
||||||
dropped = False
|
dropped = False
|
||||||
|
|
||||||
|
@ -320,23 +316,12 @@ def setup_test_homeserver(
|
||||||
# Register the cleanup hook
|
# Register the cleanup hook
|
||||||
cleanup_func(cleanup)
|
cleanup_func(cleanup)
|
||||||
|
|
||||||
hs.setup()
|
|
||||||
if homeserverToUse.__name__ == "TestHomeServer":
|
|
||||||
hs.setup_master()
|
|
||||||
else:
|
else:
|
||||||
# If we have been given an explicit datastore we probably want to mock
|
|
||||||
# out the DataStores somehow too. This all feels a bit wrong, but then
|
|
||||||
# mocking the stores feels wrong too.
|
|
||||||
datastores = Mock(datastore=datastore)
|
|
||||||
|
|
||||||
hs = homeserverToUse(
|
hs = homeserverToUse(
|
||||||
name,
|
name,
|
||||||
db_pool=None,
|
|
||||||
datastore=datastore,
|
datastore=datastore,
|
||||||
datastores=datastores,
|
|
||||||
config=config,
|
config=config,
|
||||||
version_string="Synapse/tests",
|
version_string="Synapse/tests",
|
||||||
database_engine=db_engine,
|
|
||||||
tls_server_context_factory=Mock(),
|
tls_server_context_factory=Mock(),
|
||||||
tls_client_options_factory=Mock(),
|
tls_client_options_factory=Mock(),
|
||||||
reactor=reactor,
|
reactor=reactor,
|
||||||
|
|
Loading…
Reference in a new issue