Update db

This commit is contained in:
Andrew Morgan 2019-06-05 12:55:51 +01:00
parent 899219c48c
commit 309943f2ef

View file

@ -997,10 +997,10 @@ class RegistrationStore(
client_secret/medium/(address|session_id) combo client_secret/medium/(address|session_id) combo
Args: Args:
medium (str): The medium of the 3PID medium (str|None): The medium of the 3PID
address (str): The address of the 3PID address (str|None): The address of the 3PID
sid (str): The ID of the validation session sid (str|None): The ID of the validation session
client_secret (str): A unique string provided by the client to client_secret (str|None): A unique string provided by the client to
help identify this validation attempt help identify this validation attempt
validated (bool|None): Whether sessions should be filtered by validated (bool|None): Whether sessions should be filtered by
whether they have been validated already or not. None to whether they have been validated already or not. None to
@ -1014,23 +1014,19 @@ class RegistrationStore(
keyvalues = { keyvalues = {
"medium": medium, "medium": medium,
"client_secret": client_secret, "client_secret": client_secret,
"session_id": sid,
"address": address,
} }
if sid:
keyvalues["session_id"] = sid
elif address:
keyvalues["address"] = address
else:
raise StoreError(500, "Either address or sid must be provided")
def get_threepid_validation_session_txn(txn):
cols_to_return = [ cols_to_return = [
"session_id", "medium", "address", "session_id", "medium", "address",
"client_secret", "last_send_attempt", "validated_at", "client_secret", "last_send_attempt", "validated_at",
] ]
sql = "SELECT %s FROM threepid_validation_session" % ", ".join(cols_to_return) def get_threepid_validation_session_txn(txn):
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) sql = "SELECT %s FROM threepid_validation_session WHERE %s" % (
", ".join(cols_to_return),
" AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),
)
if validated is not None: if validated is not None:
sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
@ -1038,24 +1034,17 @@ class RegistrationStore(
sql += " LIMIT 1" sql += " LIMIT 1"
txn.execute(sql, list(keyvalues.values())) txn.execute(sql, list(keyvalues.values()))
row = txn.fetchone() rows = self.cursor_to_dict(txn)
if not rows:
if not row:
return None return None
# Convert the resulting row to a dictionary return rows[0]
ret = {}
for i in range(len(cols_to_return)):
ret[cols_to_return[i]] = row[i]
return ret
return self.runInteraction( return self.runInteraction(
"get_threepid_validation_session", "get_threepid_validation_session",
get_threepid_validation_session_txn, get_threepid_validation_session_txn,
) )
@defer.inlineCallbacks
def validate_threepid_session( def validate_threepid_session(
self, self,
session_id, session_id,
@ -1077,11 +1066,13 @@ class RegistrationStore(
deferred str|None: A str representing a link to redirect the user deferred str|None: A str representing a link to redirect the user
to if there is one. to if there is one.
""" """
row = yield self._simple_select_one( # Insert everything into a transaction in order to run atomically
def validate_threepid_session_txn(txn):
row = self._simple_select_one_txn(
txn,
table="threepid_validation_session", table="threepid_validation_session",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
retcols=["client_secret", "validated_at"], retcols=["client_secret", "validated_at"],
desc="validate_threepid_session_select_session",
allow_none=True, allow_none=True,
) )
@ -1099,11 +1090,11 @@ class RegistrationStore(
400, "This client_secret does not match the provided session_id", 400, "This client_secret does not match the provided session_id",
) )
row = yield self._simple_select_one( row = self._simple_select_one_txn(
txn,
table="threepid_validation_token", table="threepid_validation_token",
keyvalues={"session_id": session_id, "token": token}, keyvalues={"session_id": session_id, "token": token},
retcols=["expires", "next_link"], retcols=["expires", "next_link"],
desc="validate_threepid_session_select_token",
allow_none=True, allow_none=True,
) )
@ -1120,15 +1111,20 @@ class RegistrationStore(
) )
# Looks good. Validate the session # Looks good. Validate the session
yield self._simple_update( self._simple_update_txn(
txn,
table="threepid_validation_session", table="threepid_validation_session",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
updatevalues={"validated_at": self.clock.time_msec()}, updatevalues={"validated_at": self.clock.time_msec()},
desc="validate_threepid_session_update",
) )
return next_link
# Return next_link if it exists # Return next_link if it exists
defer.returnValue(next_link) return self.runInteraction(
"validate_threepid_session_txn",
validate_threepid_session_txn,
)
def upsert_threepid_validation_session( def upsert_threepid_validation_session(
self, self,
@ -1201,7 +1197,6 @@ class RegistrationStore(
DELETE FROM threepid_validation_token WHERE DELETE FROM threepid_validation_token WHERE
expires < ? expires < ?
""" """
return txn.execute(sql, (ts,)) return txn.execute(sql, (ts,))
return self.runInteraction( return self.runInteraction(