diff --git a/vector/src/main/java/im/vector/app/features/onboarding/OnboardingViewModel.kt b/vector/src/main/java/im/vector/app/features/onboarding/OnboardingViewModel.kt index 54aea0185c..e5bbf39127 100644 --- a/vector/src/main/java/im/vector/app/features/onboarding/OnboardingViewModel.kt +++ b/vector/src/main/java/im/vector/app/features/onboarding/OnboardingViewModel.kt @@ -147,7 +147,7 @@ class OnboardingViewModel @AssistedInject constructor( is OnboardingAction.WebLoginSuccess -> handleWebLoginSuccess(action) is OnboardingAction.ResetPassword -> handleResetPassword(action) is OnboardingAction.ResetPasswordMailConfirmed -> handleResetPasswordMailConfirmed() - is OnboardingAction.PostRegisterAction -> handleRegisterAction(action.registerAction) + is OnboardingAction.PostRegisterAction -> handleRegisterAction(action.registerAction, ::emitFlowResultViewEvent) is OnboardingAction.ResetAction -> handleResetAction(action) is OnboardingAction.UserAcceptCertificate -> handleUserAcceptCertificate(action) OnboardingAction.ClearHomeServerHistory -> handleClearHomeServerHistory() @@ -211,7 +211,7 @@ class OnboardingViewModel @AssistedInject constructor( ?.let { it.copy(allowedFingerprints = it.allowedFingerprints + action.fingerprint) } ?.let { startAuthenticationFlow(it) } } - is OnboardingAction.LoginOrRegister -> + is OnboardingAction.LoginOrRegister -> handleDirectLogin( finalLastAction, HomeServerConnectionConfig.Builder() @@ -220,7 +220,7 @@ class OnboardingViewModel @AssistedInject constructor( .withAllowedFingerPrints(listOf(action.fingerprint)) .build() ) - else -> Unit + else -> Unit } } @@ -255,11 +255,12 @@ class OnboardingViewModel @AssistedInject constructor( } } - private fun handleRegisterAction(action: RegisterAction) { + private fun handleRegisterAction(action: RegisterAction, onNextRegistrationStepAction: (FlowResult) -> Unit) { currentJob = viewModelScope.launch { if (action.hasLoadingState()) { setState { copy(isLoading = true) } } + runCatching { registrationActionHandler.handleRegisterAction(registrationWizard, action) } .fold( onSuccess = { @@ -269,7 +270,7 @@ class OnboardingViewModel @AssistedInject constructor( } else -> when (it) { is RegistrationResult.Success -> onSessionCreated(it.session, isAccountCreated = true) - is RegistrationResult.FlowResponse -> onFlowResponse(it.flowResult) + is RegistrationResult.FlowResponse -> onFlowResponse(it.flowResult, onNextRegistrationStepAction) } } }, @@ -283,13 +284,20 @@ class OnboardingViewModel @AssistedInject constructor( } } + private fun emitFlowResultViewEvent(flowResult: FlowResult) { + _viewEvents.post(OnboardingViewEvents.RegistrationFlowResult(flowResult, isRegistrationStarted)) + } + private fun handleRegisterWith(action: OnboardingAction.Register) { reAuthHelper.data = action.password - handleRegisterAction(RegisterAction.CreateAccount( - action.username, - action.password, - action.initialDeviceName - )) + handleRegisterAction( + RegisterAction.CreateAccount( + action.username, + action.password, + action.initialDeviceName + ), + ::emitFlowResultViewEvent + ) } private fun handleResetAction(action: OnboardingAction.ResetAction) { @@ -344,7 +352,7 @@ class OnboardingViewModel @AssistedInject constructor( } when (action.signMode) { - SignMode.SignUp -> handleRegisterAction(RegisterAction.StartRegistration) + SignMode.SignUp -> handleRegisterAction(RegisterAction.StartRegistration, ::emitFlowResultViewEvent) SignMode.SignIn -> startAuthenticationFlow() SignMode.SignInWithMatrixId -> _viewEvents.post(OnboardingViewEvents.OnSignModeSelected(SignMode.SignInWithMatrixId)) SignMode.Unknown -> Unit @@ -509,18 +517,17 @@ class OnboardingViewModel @AssistedInject constructor( _viewEvents.post(OnboardingViewEvents.OnSignModeSelected(SignMode.SignIn)) } - private fun onFlowResponse(flowResult: FlowResult) { + private fun onFlowResponse(flowResult: FlowResult, onNextRegistrationStepAction: (FlowResult) -> Unit) { // If dummy stage is mandatory, and password is already sent, do the dummy stage now if (isRegistrationStarted && flowResult.missingStages.any { it is Stage.Dummy && it.mandatory }) { - handleRegisterDummy() + handleRegisterDummy(onNextRegistrationStepAction) } else { - // Notify the user - _viewEvents.post(OnboardingViewEvents.RegistrationFlowResult(flowResult, isRegistrationStarted)) + onNextRegistrationStepAction(flowResult) } } - private fun handleRegisterDummy() { - handleRegisterAction(RegisterAction.RegisterDummy) + private fun handleRegisterDummy(onNextRegistrationStepAction: (FlowResult) -> Unit) { + handleRegisterAction(RegisterAction.RegisterDummy, onNextRegistrationStepAction) } private suspend fun onSessionCreated(session: Session, isAccountCreated: Boolean) { @@ -599,19 +606,7 @@ class OnboardingViewModel @AssistedInject constructor( runCatching { startAuthenticationFlowUseCase.execute(homeServerConnectionConfig) }.fold( onSuccess = { - rememberHomeServer(homeServerConnectionConfig.homeServerUri.toString()) - if (it.isHomeserverOutdated) { - _viewEvents.post(OnboardingViewEvents.OutdatedHomeserver) - } - - setState { - copy( - serverType = alignServerTypeAfterSubmission(homeServerConnectionConfig, serverTypeOverride), - selectedHomeserver = it.selectedHomeserver, - isLoading = false, - ) - } - onAuthenticationStartedSuccess() + onAuthenticationStartedSuccess(homeServerConnectionConfig, it, serverTypeOverride) }, onFailure = { setState { copy(isLoading = false) } @@ -621,6 +616,48 @@ class OnboardingViewModel @AssistedInject constructor( } } + private fun onAuthenticationStartedSuccess(config: HomeServerConnectionConfig, authResult: StartAuthenticationFlowUseCase.StartAuthenticationResult, serverTypeOverride: ServerType?) { + rememberHomeServer(config.homeServerUri.toString()) + if (authResult.isHomeserverOutdated) { + _viewEvents.post(OnboardingViewEvents.OutdatedHomeserver) + } + + setState { + copy( + serverType = alignServerTypeAfterSubmission(config, serverTypeOverride), + selectedHomeserver = authResult.selectedHomeserver, + isLoading = false + ) + } + withState { + when (lastAction) { + is OnboardingAction.HomeServerChange.EditHomeServer -> { + when (it.onboardingFlow) { + OnboardingFlow.SignUp -> handleRegisterAction(RegisterAction.StartRegistration) { _ -> + _viewEvents.post(OnboardingViewEvents.OnHomeserverEdited) + } + else -> throw IllegalArgumentException("developer error") + } + } + is OnboardingAction.HomeServerChange.SelectHomeServer -> { + if (it.selectedHomeserver.preferredLoginMode.supportsSignModeScreen()) { + when (it.onboardingFlow) { + OnboardingFlow.SignIn -> handleUpdateSignMode(OnboardingAction.UpdateSignMode(SignMode.SignIn)) + OnboardingFlow.SignUp -> handleUpdateSignMode(OnboardingAction.UpdateSignMode(SignMode.SignUp)) + OnboardingFlow.SignInSignUp, + null -> { + _viewEvents.post(OnboardingViewEvents.OnLoginFlowRetrieved) + } + } + } else { + _viewEvents.post(OnboardingViewEvents.OnLoginFlowRetrieved) + } + } + else -> _viewEvents.post(OnboardingViewEvents.OnLoginFlowRetrieved) + } + } + } + /** * If user has entered https://matrix.org, ensure that server type is ServerType.MatrixOrg * It is also useful to set the value again in the case of a certificate error on matrix.org @@ -633,28 +670,6 @@ class OnboardingViewModel @AssistedInject constructor( } } - private fun onAuthenticationStartedSuccess() { - withState { - when (lastAction) { - is OnboardingAction.HomeServerChange.EditHomeServer -> _viewEvents.post(OnboardingViewEvents.OnHomeserverEdited) - is OnboardingAction.HomeServerChange.SelectHomeServer -> { - if (it.selectedHomeserver.preferredLoginMode.supportsSignModeScreen()) { - when (it.onboardingFlow) { - OnboardingFlow.SignIn -> handleUpdateSignMode(OnboardingAction.UpdateSignMode(SignMode.SignIn)) - OnboardingFlow.SignUp -> handleUpdateSignMode(OnboardingAction.UpdateSignMode(SignMode.SignUp)) - OnboardingFlow.SignInSignUp, - null -> { - _viewEvents.post(OnboardingViewEvents.OnLoginFlowRetrieved) - } - } - } else { - _viewEvents.post(OnboardingViewEvents.OnLoginFlowRetrieved) - } - } - else -> _viewEvents.post(OnboardingViewEvents.OnLoginFlowRetrieved) - } - } - } fun getInitialHomeServerUrl(): String? { return loginConfig?.homeServerUrl