Update db

This commit is contained in:
Andrew Morgan 2019-06-05 12:55:51 +01:00
parent 899219c48c
commit 309943f2ef

View file

@ -997,10 +997,10 @@ class RegistrationStore(
client_secret/medium/(address|session_id) combo client_secret/medium/(address|session_id) combo
Args: Args:
medium (str): The medium of the 3PID medium (str|None): The medium of the 3PID
address (str): The address of the 3PID address (str|None): The address of the 3PID
sid (str): The ID of the validation session sid (str|None): The ID of the validation session
client_secret (str): A unique string provided by the client to client_secret (str|None): A unique string provided by the client to
help identify this validation attempt help identify this validation attempt
validated (bool|None): Whether sessions should be filtered by validated (bool|None): Whether sessions should be filtered by
whether they have been validated already or not. None to whether they have been validated already or not. None to
@ -1014,23 +1014,19 @@ class RegistrationStore(
keyvalues = { keyvalues = {
"medium": medium, "medium": medium,
"client_secret": client_secret, "client_secret": client_secret,
"session_id": sid,
"address": address,
} }
cols_to_return = [
if sid: "session_id", "medium", "address",
keyvalues["session_id"] = sid "client_secret", "last_send_attempt", "validated_at",
elif address: ]
keyvalues["address"] = address
else:
raise StoreError(500, "Either address or sid must be provided")
def get_threepid_validation_session_txn(txn): def get_threepid_validation_session_txn(txn):
cols_to_return = [ sql = "SELECT %s FROM threepid_validation_session WHERE %s" % (
"session_id", "medium", "address", ", ".join(cols_to_return),
"client_secret", "last_send_attempt", "validated_at", " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),
] )
sql = "SELECT %s FROM threepid_validation_session" % ", ".join(cols_to_return)
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
if validated is not None: if validated is not None:
sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
@ -1038,24 +1034,17 @@ class RegistrationStore(
sql += " LIMIT 1" sql += " LIMIT 1"
txn.execute(sql, list(keyvalues.values())) txn.execute(sql, list(keyvalues.values()))
row = txn.fetchone() rows = self.cursor_to_dict(txn)
if not rows:
if not row:
return None return None
# Convert the resulting row to a dictionary return rows[0]
ret = {}
for i in range(len(cols_to_return)):
ret[cols_to_return[i]] = row[i]
return ret
return self.runInteraction( return self.runInteraction(
"get_threepid_validation_session", "get_threepid_validation_session",
get_threepid_validation_session_txn, get_threepid_validation_session_txn,
) )
@defer.inlineCallbacks
def validate_threepid_session( def validate_threepid_session(
self, self,
session_id, session_id,
@ -1077,58 +1066,65 @@ class RegistrationStore(
deferred str|None: A str representing a link to redirect the user deferred str|None: A str representing a link to redirect the user
to if there is one. to if there is one.
""" """
row = yield self._simple_select_one( # Insert everything into a transaction in order to run atomically
table="threepid_validation_session", def validate_threepid_session_txn(txn):
keyvalues={"session_id": session_id}, row = self._simple_select_one_txn(
retcols=["client_secret", "validated_at"], txn,
desc="validate_threepid_session_select_session", table="threepid_validation_session",
allow_none=True, keyvalues={"session_id": session_id},
) retcols=["client_secret", "validated_at"],
allow_none=True,
if not row:
raise ThreepidValidationError(400, "Unknown session_id")
retrieved_client_secret = row["client_secret"]
validated_at = row["validated_at"]
if validated_at:
raise ThreepidValidationError(
400, "This session has already been validated",
)
if retrieved_client_secret != client_secret:
raise ThreepidValidationError(
400, "This client_secret does not match the provided session_id",
) )
row = yield self._simple_select_one( if not row:
table="threepid_validation_token", raise ThreepidValidationError(400, "Unknown session_id")
keyvalues={"session_id": session_id, "token": token}, retrieved_client_secret = row["client_secret"]
retcols=["expires", "next_link"], validated_at = row["validated_at"]
desc="validate_threepid_session_select_token",
allow_none=True,
)
if not row: if validated_at:
raise ThreepidValidationError( raise ThreepidValidationError(
400, "Validation token not found or has expired", 400, "This session has already been validated",
) )
expires = row["expires"] if retrieved_client_secret != client_secret:
next_link = row["next_link"] raise ThreepidValidationError(
400, "This client_secret does not match the provided session_id",
)
if expires <= current_ts: row = self._simple_select_one_txn(
raise ThreepidValidationError( txn,
400, "This token has expired. Please request a new one", table="threepid_validation_token",
keyvalues={"session_id": session_id, "token": token},
retcols=["expires", "next_link"],
allow_none=True,
) )
# Looks good. Validate the session if not row:
yield self._simple_update( raise ThreepidValidationError(
table="threepid_validation_session", 400, "Validation token not found or has expired",
keyvalues={"session_id": session_id}, )
updatevalues={"validated_at": self.clock.time_msec()}, expires = row["expires"]
desc="validate_threepid_session_update", next_link = row["next_link"]
)
if expires <= current_ts:
raise ThreepidValidationError(
400, "This token has expired. Please request a new one",
)
# Looks good. Validate the session
self._simple_update_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
updatevalues={"validated_at": self.clock.time_msec()},
)
return next_link
# Return next_link if it exists # Return next_link if it exists
defer.returnValue(next_link) return self.runInteraction(
"validate_threepid_session_txn",
validate_threepid_session_txn,
)
def upsert_threepid_validation_session( def upsert_threepid_validation_session(
self, self,
@ -1201,7 +1197,6 @@ class RegistrationStore(
DELETE FROM threepid_validation_token WHERE DELETE FROM threepid_validation_token WHERE
expires < ? expires < ?
""" """
return txn.execute(sql, (ts,)) return txn.execute(sql, (ts,))
return self.runInteraction( return self.runInteraction(