From 774d561540a7b236e1ac0581ec37e95ab5412840 Mon Sep 17 00:00:00 2001
From: jano3 <jano@uafrica.com>
Date: Tue, 21 Feb 2023 15:56:55 +0200
Subject: [PATCH] Update ConfirmPasswordReset to be able to handle both forgot
 password and user confirmation

---
 cognito/cognito.go | 49 +++++++++++++++++++++++++++++++++++++++++++++-
 errors/errors.go   | 11 +++++++++++
 2 files changed, 59 insertions(+), 1 deletion(-)

diff --git a/cognito/cognito.go b/cognito/cognito.go
index 079b200..d121591 100644
--- a/cognito/cognito.go
+++ b/cognito/cognito.go
@@ -2,6 +2,8 @@ package cognito
 
 import (
 	"fmt"
+	"gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/errors"
+	"gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/utils"
 	"math/rand"
 	"strings"
 
@@ -90,7 +92,7 @@ func SetUserPassword(pool string, username string, password string) (*cognitoide
 	return output, err
 }
 
-func ConfirmPasswordReset(appClientID string, username string, password string, confirmationCode string) (*cognitoidentityprovider.ConfirmForgotPasswordOutput, error) {
+func confirmForgotPassword(appClientID string, username string, password string, confirmationCode string) (*cognitoidentityprovider.ConfirmForgotPasswordOutput, error) {
 	input := cognitoidentityprovider.ConfirmForgotPasswordInput{
 		ClientId:         &appClientID,
 		ConfirmationCode: &confirmationCode,
@@ -102,6 +104,51 @@ func ConfirmPasswordReset(appClientID string, username string, password string,
 	return output, err
 }
 
+func confirmPasswordReset(appClientID string, username string, password string, initiateAuthOutput *cognitoidentityprovider.InitiateAuthOutput) (*cognitoidentityprovider.RespondToAuthChallengeOutput, error) {
+	// Respond to the Auth challenge to change the user's password
+	authChallengeParameters := map[string]*string{
+		"USERNAME":     utils.PointerValue(username),
+		"NEW_PASSWORD": utils.PointerValue(password),
+	}
+	respondToAuthChallengeInput := cognitoidentityprovider.RespondToAuthChallengeInput{
+		ChallengeName:      initiateAuthOutput.ChallengeName,
+		ChallengeResponses: authChallengeParameters,
+		ClientId:           &appClientID,
+		Session:            initiateAuthOutput.Session,
+	}
+	output, err := CognitoService.RespondToAuthChallenge(&respondToAuthChallengeInput)
+	logs.Info("output", output)
+	return output, err
+}
+
+// ConfirmPasswordReset initiates a Cognito auth for the user, and based on the output either updates the user's password,
+// or performs a forgot password confirmation.
+func ConfirmPasswordReset(appClientID string, username string, password string, confirmationCode string) (interface{}, error) {
+	// Initiate an auth for the user to see if a password reset or
+	authParameters := map[string]*string{
+		"USERNAME": utils.PointerValue(username),
+		"PASSWORD": utils.PointerValue(confirmationCode),
+	}
+	initiateAuthInput := cognitoidentityprovider.InitiateAuthInput{
+		AuthFlow:       utils.PointerValue(cognitoidentityprovider.ExplicitAuthFlowsTypeUserPasswordAuth),
+		AuthParameters: authParameters,
+		ClientId:       &appClientID,
+	}
+	res, err := CognitoService.InitiateAuth(&initiateAuthInput)
+	if err != nil {
+		if errors.AWSErrorExceptionCode(err) == cognitoidentityprovider.ErrCodePasswordResetRequiredException {
+			// Not a user verification - perform forgot password confirmation
+			return confirmForgotPassword(appClientID, username, password, confirmationCode)
+		}
+		return nil, err
+	}
+	if utils.Unwrap(res.ChallengeName) == cognitoidentityprovider.ChallengeNameTypeNewPasswordRequired {
+		return confirmPasswordReset(appClientID, username, password, res)
+	}
+
+	return nil, errors.New("User state not correct for confirmation. Please contact support.")
+}
+
 // FOR API LOGS
 
 func DetermineAuthType(identity events.APIGatewayRequestIdentity) *string {
diff --git a/errors/errors.go b/errors/errors.go
index d302aa6..3c6ab5e 100644
--- a/errors/errors.go
+++ b/errors/errors.go
@@ -112,6 +112,17 @@ func HTTPWithError(code int, err error) error {
 	return wrappedErr
 }
 
+func AWSErrorExceptionCode(err error) string {
+	if err == nil {
+		return ""
+	}
+
+	if awsError, ok := err.(awserr.Error); ok {
+		return awsError.Code()
+	}
+	return ""
+}
+
 func AWSErrorWithoutExceptionCode(err error) error {
 	if err == nil {
 		return nil
-- 
GitLab