mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-28 15:08:49 +03:00
DB schema interface for password auth providers
Provide an interface by which password auth providers can register db schema files to be run at startup
This commit is contained in:
parent
c31a7c3ff6
commit
1650eb5847
3 changed files with 89 additions and 0 deletions
|
@ -37,3 +37,15 @@ Password auth provider classes must provide the following methods:
|
||||||
|
|
||||||
The method should return a Twisted ``Deferred`` object, which resolves to
|
The method should return a Twisted ``Deferred`` object, which resolves to
|
||||||
``True`` if authentication is successful, and ``False`` if not.
|
``True`` if authentication is successful, and ``False`` if not.
|
||||||
|
|
||||||
|
Optional methods
|
||||||
|
----------------
|
||||||
|
|
||||||
|
Password provider classes may optionally provide the following methods.
|
||||||
|
|
||||||
|
*class* ``SomeProvider.get_db_schema_files()``
|
||||||
|
|
||||||
|
This method, if implemented, should return an Iterable of ``(name,
|
||||||
|
stream)`` pairs of database schema files. Each file is applied in turn at
|
||||||
|
initialisation, and a record is then made in the database so that it is
|
||||||
|
not re-applied on the next start.
|
||||||
|
|
|
@ -44,6 +44,13 @@ def prepare_database(db_conn, database_engine, config):
|
||||||
|
|
||||||
If `config` is None then prepare_database will assert that no upgrade is
|
If `config` is None then prepare_database will assert that no upgrade is
|
||||||
necessary, *or* will create a fresh database if the database is empty.
|
necessary, *or* will create a fresh database if the database is empty.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_conn:
|
||||||
|
database_engine:
|
||||||
|
config (synapse.config.homeserver.HomeServerConfig|None):
|
||||||
|
application config, or None if we are connecting to an existing
|
||||||
|
database which we expect to be configured already
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
cur = db_conn.cursor()
|
cur = db_conn.cursor()
|
||||||
|
@ -64,6 +71,10 @@ def prepare_database(db_conn, database_engine, config):
|
||||||
else:
|
else:
|
||||||
_setup_new_database(cur, database_engine)
|
_setup_new_database(cur, database_engine)
|
||||||
|
|
||||||
|
# check if any of our configured dynamic modules want a database
|
||||||
|
if config is not None:
|
||||||
|
_apply_module_schemas(cur, database_engine, config)
|
||||||
|
|
||||||
cur.close()
|
cur.close()
|
||||||
db_conn.commit()
|
db_conn.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -283,6 +294,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_module_schemas(txn, database_engine, config):
|
||||||
|
"""Apply the module schemas for the dynamic modules, if any
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cur: database cursor
|
||||||
|
database_engine: synapse database engine class
|
||||||
|
config (synapse.config.homeserver.HomeServerConfig):
|
||||||
|
application config
|
||||||
|
"""
|
||||||
|
for (mod, _config) in config.password_providers:
|
||||||
|
if not hasattr(mod, 'get_db_schema_files'):
|
||||||
|
continue
|
||||||
|
modname = ".".join((mod.__module__, mod.__name__))
|
||||||
|
_apply_module_schema_files(
|
||||||
|
txn, database_engine, modname, mod.get_db_schema_files(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
|
||||||
|
"""Apply the module schemas for a single module
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cur: database cursor
|
||||||
|
database_engine: synapse database engine class
|
||||||
|
modname (str): fully qualified name of the module
|
||||||
|
names_and_streams (Iterable[(str, file)]): the names and streams of
|
||||||
|
schemas to be applied
|
||||||
|
"""
|
||||||
|
cur.execute(
|
||||||
|
database_engine.convert_param_style(
|
||||||
|
"SELECT file FROM applied_module_schemas WHERE module_name = ?"
|
||||||
|
),
|
||||||
|
(modname,)
|
||||||
|
)
|
||||||
|
applied_deltas = set(d for d, in cur)
|
||||||
|
for (name, stream) in names_and_streams:
|
||||||
|
if name in applied_deltas:
|
||||||
|
continue
|
||||||
|
|
||||||
|
root_name, ext = os.path.splitext(name)
|
||||||
|
if ext != '.sql':
|
||||||
|
raise PrepareDatabaseException(
|
||||||
|
"only .sql files are currently supported for module schemas",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("applying schema %s for %s", name, modname)
|
||||||
|
for statement in get_statements(stream):
|
||||||
|
cur.execute(statement)
|
||||||
|
|
||||||
|
# Mark as done.
|
||||||
|
cur.execute(
|
||||||
|
database_engine.convert_param_style(
|
||||||
|
"INSERT INTO applied_module_schemas (module_name, file)"
|
||||||
|
" VALUES (?,?)",
|
||||||
|
),
|
||||||
|
(modname, name)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_statements(f):
|
def get_statements(f):
|
||||||
statement_buffer = ""
|
statement_buffer = ""
|
||||||
in_comment = False # If we're in a /* ... */ style comment
|
in_comment = False # If we're in a /* ... */ style comment
|
||||||
|
|
|
@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas(
|
||||||
file TEXT NOT NULL,
|
file TEXT NOT NULL,
|
||||||
UNIQUE(version, file)
|
UNIQUE(version, file)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
-- a list of schema files we have loaded on behalf of dynamic modules
|
||||||
|
CREATE TABLE IF NOT EXISTS applied_module_schemas(
|
||||||
|
module_name TEXT NOT NULL,
|
||||||
|
file TEXT NOT NULL,
|
||||||
|
UNIQUE(module_name, file)
|
||||||
|
);
|
||||||
|
|
Loading…
Reference in a new issue