diff --git a/auth/common.go b/auth/common.go index d411101a2d2e96b1489bf5bbab3e4d1728e16190..c9a37eccab18bc1bd1be9661cba91183037e93ad 100644 --- a/auth/common.go +++ b/auth/common.go @@ -1,16 +1,14 @@ package auth import ( + "gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/handler_utils" "golang.org/x/crypto/bcrypt" "strings" ) // GetBearerTokenFromHeaders checks if a bearer token is passed as part of the Authorization header and returns that key func GetBearerTokenFromHeaders(headers map[string]string) string { - headerValue := headers["authorization"] - if headerValue == "" { - headerValue = headers["Authorization"] - } + headerValue := handler_utils.FindHeaderValue(headers, "authorization") if strings.HasPrefix(strings.ToLower(headerValue), "bearer ") { headerValues := strings.Split(headerValue, " ") return strings.TrimSpace(headerValues[1]) diff --git a/auth/jwt.go b/auth/jwt.go index 2f9c35ef01dd5664fc0f7550cc4e297ccb620db1..705995f10d58febd74d6512373ab859f79b019b8 100644 --- a/auth/jwt.go +++ b/auth/jwt.go @@ -11,13 +11,19 @@ import ( type JsonWebToken struct { UserID string `json:"user_id"` - Password string `json:"password"` ProviderID int64 `json:"provider_id,omitempty"` ExpiryDate time.Time `json:"expiry_date"` } +// GenerateJWTWithSessionToken first signs the session token with the secret key, then takes the payload and generates a +// signed JWT using the resulting signed session token +func GenerateJWTWithSessionToken(payload JsonWebToken, secretKey string, sessionToken string) (string, error) { + signedSessionToken := SignSessionTokenWithKey(secretKey, sessionToken) + return GenerateJWT(payload, signedSessionToken) +} + // GenerateJWT takes the payload and generates a signed JWT using the provided secret -func GenerateJWT(payload JsonWebToken, secret []byte) (string, error) { +func GenerateJWT(payload JsonWebToken, secretKey string) (string, error) { // Convert the JsonWebToken to a map[string]interface{} tokenBytes, err := json.Marshal(payload) if err != nil { @@ -32,7 +38,7 @@ func GenerateJWT(payload JsonWebToken, secret []byte) (string, error) { // Create the signed token token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims(tokenMap)) - tokenString, err := token.SignedString(secret) + tokenString, err := token.SignedString([]byte(secretKey)) if err != nil { return "", err } @@ -40,25 +46,13 @@ func GenerateJWT(payload JsonWebToken, secret []byte) (string, error) { return tokenString, nil } -// ValidateJWT parses the JWT and validates that it is signed correctly -func ValidateJWT(tokenString string, secret []byte) (JsonWebToken, error) { - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - // Validate the signing method - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, errors.Errorf("unexpected signing method: %v", token.Header["alg"]) - } - - return secret, nil - }) - if err != nil { - return JsonWebToken{}, err - } +func getJsonWebTokenFromTokenClaims(token *jwt.Token, checkValidity bool) (JsonWebToken, error) { if token == nil { return JsonWebToken{}, errors.Error("could not get token from token string") } claims, ok := token.Claims.(jwt.MapClaims) - if !ok || !token.Valid { + if !ok || (checkValidity && token.Valid == false) { return JsonWebToken{}, errors.Error("invalid token") } @@ -74,6 +68,37 @@ func ValidateJWT(tokenString string, secret []byte) (JsonWebToken, error) { return JsonWebToken{}, err } + return jsonWebToken, nil +} + +// ValidateJWTWithSessionToken first signs the session token using the secret key, then parses the JWT and validates +// that it is signed correctly +func ValidateJWTWithSessionToken(jsonWebTokenString string, secretKey string, sessionToken string) (JsonWebToken, error) { + // Sign the session token with the secret key - this prevents the JWT from being used by other sessions + signedSecret := SignSessionTokenWithKey(secretKey, sessionToken) + return ValidateJWT(jsonWebTokenString, signedSecret) +} + +// ValidateJWT parses the JWT and validates that it is signed correctly +func ValidateJWT(jsonWebTokenString string, secretKey string) (JsonWebToken, error) { + // Validate the + token, err := jwt.Parse(jsonWebTokenString, func(token *jwt.Token) (interface{}, error) { + // Validate the signing method + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, errors.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + return []byte(secretKey), nil + }) + if err != nil { + return JsonWebToken{}, err + } + + jsonWebToken, err := getJsonWebTokenFromTokenClaims(token, true) + if err != nil { + return JsonWebToken{}, err + } + // Validate the expiry date if jsonWebToken.ExpiryDate.Before(date_utils.CurrentDate()) { return jsonWebToken, errors.Error("token has expired") @@ -86,8 +111,33 @@ func ValidateJWT(tokenString string, secret []byte) (JsonWebToken, error) { // using the provided encryption key. func LoginWithPassword(password string, hashedPassword string, jsonWebToken JsonWebToken, jwtEncryptionKey string) (string, error) { if PasswordIsCorrect(password, hashedPassword) { - return GenerateJWT(jsonWebToken, []byte(jwtEncryptionKey)) + return GenerateJWT(jsonWebToken, jwtEncryptionKey) + } + + return "", errors.HTTPWithMsg(http.StatusBadRequest, "password is incorrect") +} + +// LoginSessionWithPassword checks that the provided password is correct. If the password is correct, the session token +// is signed using the secret key, and a JWT is returned using the signed session token +func LoginSessionWithPassword(password string, hashedPassword string, jsonWebToken JsonWebToken, secretKey string, sessionToken string) (string, error) { + if PasswordIsCorrect(password, hashedPassword) { + return GenerateJWTWithSessionToken(jsonWebToken, secretKey, sessionToken) } return "", errors.HTTPWithMsg(http.StatusBadRequest, "password is incorrect") } + +// GetUserIDFromJWTWithoutValidation gets the userID from the jsonWebTokenString without validating the signature. +// Successful execution of this function DOES NOT indicate that the JWT is valid in any way. +func GetUserIDFromJWTWithoutValidation(jsonWebTokenString string) string { + token, _, err := jwt.NewParser().ParseUnverified(jsonWebTokenString, jwt.MapClaims{}) + if err != nil { + return "" + } + + jsonWebToken, err := getJsonWebTokenFromTokenClaims(token, false) + if err != nil { + return "" + } + return jsonWebToken.UserID +} diff --git a/auth/session.go b/auth/session.go new file mode 100644 index 0000000000000000000000000000000000000000..c7c81e4a0cf1f0bbef631822438a8c5cc1d55fb3 --- /dev/null +++ b/auth/session.go @@ -0,0 +1,106 @@ +package auth + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "github.com/aws/aws-lambda-go/events" + "github.com/google/uuid" + "github.com/thoas/go-funk" + "gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/date_utils" + "gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/handler_utils" + "time" +) + +type SessionToken struct { + IP string `json:"ip"` + UserAgent string `json:"user_agent"` + TimeCreated time.Time `json:"time_created"` + Token string `json:"session_token"` +} + +// SignSessionTokenWithKey signs the session token string using the secret. +func SignSessionTokenWithKey(secretKey string, sessionTokenString string) string { + signer := hmac.New(sha256.New, []byte(secretKey)) + signer.Write([]byte(sessionTokenString)) + return hex.EncodeToString(signer.Sum(nil)) +} + +// GetSessionTokenString creates a unique session token string from the API request. +func GetSessionTokenString(request events.APIGatewayProxyRequest) (string, error) { + sessionToken := SessionToken{ + IP: request.RequestContext.Identity.SourceIP, + UserAgent: handler_utils.FindHeaderValue(request.Headers, "user-agent"), + TimeCreated: date_utils.CurrentDate().Round(time.Second), + Token: uuid.New().String(), + } + sessionTokenBytes, err := json.Marshal(sessionToken) + return string(sessionTokenBytes), err +} + +// GetSignedSessionTokenString creates a unique session token string from the API request and signs it using the secret. +func GetSignedSessionTokenString(request events.APIGatewayProxyRequest, secretKey string) (string, error) { + sessionTokenBytes, err := GetSessionTokenString(request) + if err != nil { + return "", err + } + + return SignSessionTokenWithKey(secretKey, sessionTokenBytes), nil +} + +// ValidateJWTWithSessionTokens attempts to validate the JWT string by signing each session token using the secret, and +// using the resulting signed session token to validate the JWT. If the JWT can be validated using a session token, the +// JsonWebToken is returned, otherwise nil is returned. +func ValidateJWTWithSessionTokens(jsonWebTokenString string, secretKey string, sessionTokens []string) *JsonWebToken { + // Test each session token to find one that is valid + for _, sessionToken := range sessionTokens { + jsonWebToken, err := ValidateJWTWithSessionToken(jsonWebTokenString, secretKey, sessionToken) + if err == nil { + return &jsonWebToken + } + } + + return nil +} + +// FindAndRemoveCurrentSessionToken attempts to validate the JWT string by signing each session token using the secret, +// and using the resulting signed session token to validate the JWT. If the JWT is successfully validated with one of +// the session tokens, the session token is removed from the slice, otherwise the original session token +// slice is returned. +func FindAndRemoveCurrentSessionToken(jsonWebTokenString string, secretKey string, sessionTokens []string) (string, []string) { + // Test each session token to find one that is valid + for _, sessionToken := range sessionTokens { + _, err := ValidateJWTWithSessionToken(jsonWebTokenString, secretKey, sessionToken) + if err == nil { + // Remove this session token from the slice + updatedSessionTokens := funk.FilterString(sessionTokens, func(token string) bool { + return token != sessionToken + }) + return sessionToken, updatedSessionTokens + } + } + + return "", sessionTokens +} + +// RemoveOldSessionTokens checks the age of the session tokens and removes the ones that are older than the provided age. +func RemoveOldSessionTokens(sessionTokens []string, age time.Duration) []string { + var validTokens []string + + oneWeekAgo := date_utils.CurrentDate().Add(-1 * age) + for _, sessionTokenString := range sessionTokens { + var sessionToken SessionToken + err := json.Unmarshal([]byte(sessionTokenString), &sessionToken) + if err != nil { + // If we can't unmarshal the token then it is not valid + continue + } + + // Keep the token if it was created in the past week + if sessionToken.TimeCreated.In(date_utils.CurrentLocation()).After(oneWeekAgo) { + validTokens = append(validTokens, sessionTokenString) + } + } + return validTokens +} diff --git a/handler_utils/request.go b/handler_utils/request.go index e383adc99929df3ab9dee872ad05170df2f31675..5633f913580ac8f235a8dd63744a68386ef6411b 100644 --- a/handler_utils/request.go +++ b/handler_utils/request.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "os" + "strings" "time" "github.com/aws/aws-lambda-go/lambdacontext" @@ -65,3 +66,12 @@ func SignAWSHttpRequest(request *http.Request, accessKeyID, secretAccessKey stri return nil } + +func FindHeaderValue(headers map[string]string, key string) string { + for k, v := range headers { + if strings.ToLower(k) == strings.ToLower(key) { + return v + } + } + return "" +}