diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index e4aa386a2d..54200d621d 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -997,10 +997,10 @@ class RegistrationStore( client_secret/medium/(address|session_id) combo Args: - medium (str): The medium of the 3PID - address (str): The address of the 3PID - sid (str): The ID of the validation session - client_secret (str): A unique string provided by the client to + medium (str|None): The medium of the 3PID + address (str|None): The address of the 3PID + sid (str|None): The ID of the validation session + client_secret (str|None): A unique string provided by the client to help identify this validation attempt validated (bool|None): Whether sessions should be filtered by whether they have been validated already or not. None to @@ -1014,23 +1014,19 @@ class RegistrationStore( keyvalues = { "medium": medium, "client_secret": client_secret, + "session_id": sid, + "address": address, } - - if sid: - keyvalues["session_id"] = sid - elif address: - keyvalues["address"] = address - else: - raise StoreError(500, "Either address or sid must be provided") + cols_to_return = [ + "session_id", "medium", "address", + "client_secret", "last_send_attempt", "validated_at", + ] def get_threepid_validation_session_txn(txn): - cols_to_return = [ - "session_id", "medium", "address", - "client_secret", "last_send_attempt", "validated_at", - ] - - sql = "SELECT %s FROM threepid_validation_session" % ", ".join(cols_to_return) - sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + sql = "SELECT %s FROM threepid_validation_session WHERE %s" % ( + ", ".join(cols_to_return), + " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)), + ) if validated is not None: sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") @@ -1038,24 +1034,17 @@ class RegistrationStore( sql += " LIMIT 1" txn.execute(sql, list(keyvalues.values())) - row = txn.fetchone() - - if not row: + rows = self.cursor_to_dict(txn) + if not rows: return None - # Convert the resulting row to a dictionary - ret = {} - for i in range(len(cols_to_return)): - ret[cols_to_return[i]] = row[i] - - return ret + return rows[0] return self.runInteraction( "get_threepid_validation_session", get_threepid_validation_session_txn, ) - @defer.inlineCallbacks def validate_threepid_session( self, session_id, @@ -1077,58 +1066,65 @@ class RegistrationStore( deferred str|None: A str representing a link to redirect the user to if there is one. """ - row = yield self._simple_select_one( - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - retcols=["client_secret", "validated_at"], - desc="validate_threepid_session_select_session", - allow_none=True, - ) - - if not row: - raise ThreepidValidationError(400, "Unknown session_id") - retrieved_client_secret = row["client_secret"] - validated_at = row["validated_at"] - - if validated_at: - raise ThreepidValidationError( - 400, "This session has already been validated", - ) - if retrieved_client_secret != client_secret: - raise ThreepidValidationError( - 400, "This client_secret does not match the provided session_id", + # Insert everything into a transaction in order to run atomically + def validate_threepid_session_txn(txn): + row = self._simple_select_one_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + retcols=["client_secret", "validated_at"], + allow_none=True, ) - row = yield self._simple_select_one( - table="threepid_validation_token", - keyvalues={"session_id": session_id, "token": token}, - retcols=["expires", "next_link"], - desc="validate_threepid_session_select_token", - allow_none=True, - ) + if not row: + raise ThreepidValidationError(400, "Unknown session_id") + retrieved_client_secret = row["client_secret"] + validated_at = row["validated_at"] - if not row: - raise ThreepidValidationError( - 400, "Validation token not found or has expired", - ) - expires = row["expires"] - next_link = row["next_link"] + if validated_at: + raise ThreepidValidationError( + 400, "This session has already been validated", + ) + if retrieved_client_secret != client_secret: + raise ThreepidValidationError( + 400, "This client_secret does not match the provided session_id", + ) - if expires <= current_ts: - raise ThreepidValidationError( - 400, "This token has expired. Please request a new one", + row = self._simple_select_one_txn( + txn, + table="threepid_validation_token", + keyvalues={"session_id": session_id, "token": token}, + retcols=["expires", "next_link"], + allow_none=True, ) - # Looks good. Validate the session - yield self._simple_update( - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - updatevalues={"validated_at": self.clock.time_msec()}, - desc="validate_threepid_session_update", - ) + if not row: + raise ThreepidValidationError( + 400, "Validation token not found or has expired", + ) + expires = row["expires"] + next_link = row["next_link"] + + if expires <= current_ts: + raise ThreepidValidationError( + 400, "This token has expired. Please request a new one", + ) + + # Looks good. Validate the session + self._simple_update_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + updatevalues={"validated_at": self.clock.time_msec()}, + ) + + return next_link # Return next_link if it exists - defer.returnValue(next_link) + return self.runInteraction( + "validate_threepid_session_txn", + validate_threepid_session_txn, + ) def upsert_threepid_validation_session( self, @@ -1201,7 +1197,6 @@ class RegistrationStore( DELETE FROM threepid_validation_token WHERE expires < ? """ - return txn.execute(sql, (ts,)) return self.runInteraction(