Skip to content
Snippets Groups Projects
Commit 4068ac11 authored by Jano Hendriks's avatar Jano Hendriks
Browse files

Add session token functions for JWT

parent 4a4d4a2b
No related branches found
No related tags found
No related merge requests found
package auth
import (
"gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/handler_utils"
"golang.org/x/crypto/bcrypt"
"strings"
)
// GetBearerTokenFromHeaders checks if a bearer token is passed as part of the Authorization header and returns that key
func GetBearerTokenFromHeaders(headers map[string]string) string {
headerValue := headers["authorization"]
if headerValue == "" {
headerValue = headers["Authorization"]
}
headerValue := handler_utils.FindHeaderValue(headers, "authorization")
if strings.HasPrefix(strings.ToLower(headerValue), "bearer ") {
headerValues := strings.Split(headerValue, " ")
return strings.TrimSpace(headerValues[1])
......
......@@ -11,13 +11,19 @@ import (
type JsonWebToken struct {
UserID string `json:"user_id"`
Password string `json:"password"`
ProviderID int64 `json:"provider_id,omitempty"`
ExpiryDate time.Time `json:"expiry_date"`
}
// GenerateJWTWithSessionToken first signs the session token with the secret key, then takes the payload and generates a
// signed JWT using the resulting signed session token
func GenerateJWTWithSessionToken(payload JsonWebToken, secretKey string, sessionToken string) (string, error) {
signedSessionToken := SignSessionTokenWithKey(secretKey, sessionToken)
return GenerateJWT(payload, signedSessionToken)
}
// GenerateJWT takes the payload and generates a signed JWT using the provided secret
func GenerateJWT(payload JsonWebToken, secret []byte) (string, error) {
func GenerateJWT(payload JsonWebToken, secretKey string) (string, error) {
// Convert the JsonWebToken to a map[string]interface{}
tokenBytes, err := json.Marshal(payload)
if err != nil {
......@@ -32,7 +38,7 @@ func GenerateJWT(payload JsonWebToken, secret []byte) (string, error) {
// Create the signed token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims(tokenMap))
tokenString, err := token.SignedString(secret)
tokenString, err := token.SignedString([]byte(secretKey))
if err != nil {
return "", err
}
......@@ -40,25 +46,13 @@ func GenerateJWT(payload JsonWebToken, secret []byte) (string, error) {
return tokenString, nil
}
// ValidateJWT parses the JWT and validates that it is signed correctly
func ValidateJWT(tokenString string, secret []byte) (JsonWebToken, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// Validate the signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, errors.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return secret, nil
})
if err != nil {
return JsonWebToken{}, err
}
func getJsonWebTokenFromTokenClaims(token *jwt.Token, checkValidity bool) (JsonWebToken, error) {
if token == nil {
return JsonWebToken{}, errors.Error("could not get token from token string")
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
if !ok || (checkValidity && token.Valid == false) {
return JsonWebToken{}, errors.Error("invalid token")
}
......@@ -74,6 +68,37 @@ func ValidateJWT(tokenString string, secret []byte) (JsonWebToken, error) {
return JsonWebToken{}, err
}
return jsonWebToken, nil
}
// ValidateJWTWithSessionToken first signs the session token using the secret key, then parses the JWT and validates
// that it is signed correctly
func ValidateJWTWithSessionToken(jsonWebTokenString string, secretKey string, sessionToken string) (JsonWebToken, error) {
// Sign the session token with the secret key - this prevents the JWT from being used by other sessions
signedSecret := SignSessionTokenWithKey(secretKey, sessionToken)
return ValidateJWT(jsonWebTokenString, signedSecret)
}
// ValidateJWT parses the JWT and validates that it is signed correctly
func ValidateJWT(jsonWebTokenString string, secretKey string) (JsonWebToken, error) {
// Validate the
token, err := jwt.Parse(jsonWebTokenString, func(token *jwt.Token) (interface{}, error) {
// Validate the signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, errors.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(secretKey), nil
})
if err != nil {
return JsonWebToken{}, err
}
jsonWebToken, err := getJsonWebTokenFromTokenClaims(token, true)
if err != nil {
return JsonWebToken{}, err
}
// Validate the expiry date
if jsonWebToken.ExpiryDate.Before(date_utils.CurrentDate()) {
return jsonWebToken, errors.Error("token has expired")
......@@ -86,8 +111,33 @@ func ValidateJWT(tokenString string, secret []byte) (JsonWebToken, error) {
// using the provided encryption key.
func LoginWithPassword(password string, hashedPassword string, jsonWebToken JsonWebToken, jwtEncryptionKey string) (string, error) {
if PasswordIsCorrect(password, hashedPassword) {
return GenerateJWT(jsonWebToken, []byte(jwtEncryptionKey))
return GenerateJWT(jsonWebToken, jwtEncryptionKey)
}
return "", errors.HTTPWithMsg(http.StatusBadRequest, "password is incorrect")
}
// LoginSessionWithPassword checks that the provided password is correct. If the password is correct, the session token
// is signed using the secret key, and a JWT is returned using the signed session token
func LoginSessionWithPassword(password string, hashedPassword string, jsonWebToken JsonWebToken, secretKey string, sessionToken string) (string, error) {
if PasswordIsCorrect(password, hashedPassword) {
return GenerateJWTWithSessionToken(jsonWebToken, secretKey, sessionToken)
}
return "", errors.HTTPWithMsg(http.StatusBadRequest, "password is incorrect")
}
// GetUserIDFromJWTWithoutValidation gets the userID from the jsonWebTokenString without validating the signature.
// Successful execution of this function DOES NOT indicate that the JWT is valid in any way.
func GetUserIDFromJWTWithoutValidation(jsonWebTokenString string) string {
token, _, err := jwt.NewParser().ParseUnverified(jsonWebTokenString, jwt.MapClaims{})
if err != nil {
return ""
}
jsonWebToken, err := getJsonWebTokenFromTokenClaims(token, false)
if err != nil {
return ""
}
return jsonWebToken.UserID
}
package auth
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"github.com/aws/aws-lambda-go/events"
"github.com/google/uuid"
"github.com/thoas/go-funk"
"gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/date_utils"
"gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/handler_utils"
"time"
)
type SessionToken struct {
IP string `json:"ip"`
UserAgent string `json:"user_agent"`
TimeCreated time.Time `json:"time_created"`
Token string `json:"session_token"`
}
// SignSessionTokenWithKey signs the session token string using the secret.
func SignSessionTokenWithKey(secretKey string, sessionTokenString string) string {
signer := hmac.New(sha256.New, []byte(secretKey))
signer.Write([]byte(sessionTokenString))
return hex.EncodeToString(signer.Sum(nil))
}
// GetSessionTokenString creates a unique session token string from the API request.
func GetSessionTokenString(request events.APIGatewayProxyRequest) (string, error) {
sessionToken := SessionToken{
IP: request.RequestContext.Identity.SourceIP,
UserAgent: handler_utils.FindHeaderValue(request.Headers, "user-agent"),
TimeCreated: date_utils.CurrentDate().Round(time.Second),
Token: uuid.New().String(),
}
sessionTokenBytes, err := json.Marshal(sessionToken)
return string(sessionTokenBytes), err
}
// GetSignedSessionTokenString creates a unique session token string from the API request and signs it using the secret.
func GetSignedSessionTokenString(request events.APIGatewayProxyRequest, secretKey string) (string, error) {
sessionTokenBytes, err := GetSessionTokenString(request)
if err != nil {
return "", err
}
return SignSessionTokenWithKey(secretKey, sessionTokenBytes), nil
}
// ValidateJWTWithSessionTokens attempts to validate the JWT string by signing each session token using the secret, and
// using the resulting signed session token to validate the JWT. If the JWT can be validated using a session token, the
// JsonWebToken is returned, otherwise nil is returned.
func ValidateJWTWithSessionTokens(jsonWebTokenString string, secretKey string, sessionTokens []string) *JsonWebToken {
// Test each session token to find one that is valid
for _, sessionToken := range sessionTokens {
jsonWebToken, err := ValidateJWTWithSessionToken(jsonWebTokenString, secretKey, sessionToken)
if err == nil {
return &jsonWebToken
}
}
return nil
}
// FindAndRemoveCurrentSessionToken attempts to validate the JWT string by signing each session token using the secret,
// and using the resulting signed session token to validate the JWT. If the JWT is successfully validated with one of
// the session tokens, the session token is removed from the slice, otherwise the original session token
// slice is returned.
func FindAndRemoveCurrentSessionToken(jsonWebTokenString string, secretKey string, sessionTokens []string) (string, []string) {
// Test each session token to find one that is valid
for _, sessionToken := range sessionTokens {
_, err := ValidateJWTWithSessionToken(jsonWebTokenString, secretKey, sessionToken)
if err == nil {
// Remove this session token from the slice
updatedSessionTokens := funk.FilterString(sessionTokens, func(token string) bool {
return token != sessionToken
})
return sessionToken, updatedSessionTokens
}
}
return "", sessionTokens
}
// RemoveOldSessionTokens checks the age of the session tokens and removes the ones that are older than the provided age.
func RemoveOldSessionTokens(sessionTokens []string, age time.Duration) []string {
var validTokens []string
oneWeekAgo := date_utils.CurrentDate().Add(-1 * age)
for _, sessionTokenString := range sessionTokens {
var sessionToken SessionToken
err := json.Unmarshal([]byte(sessionTokenString), &sessionToken)
if err != nil {
// If we can't unmarshal the token then it is not valid
continue
}
// Keep the token if it was created in the past week
if sessionToken.TimeCreated.In(date_utils.CurrentLocation()).After(oneWeekAgo) {
validTokens = append(validTokens, sessionTokenString)
}
}
return validTokens
}
......@@ -9,6 +9,7 @@ import (
"io"
"net/http"
"os"
"strings"
"time"
"github.com/aws/aws-lambda-go/lambdacontext"
......@@ -65,3 +66,12 @@ func SignAWSHttpRequest(request *http.Request, accessKeyID, secretAccessKey stri
return nil
}
func FindHeaderValue(headers map[string]string, key string) string {
for k, v := range headers {
if strings.ToLower(k) == strings.ToLower(key) {
return v
}
}
return ""
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment