diff --git a/secrets_manager/secrets_manager.go b/secrets_manager/secrets_manager.go index dfe443f2349c1ad0847c0d2331340b08beac2ecd..812b5534add89f1fb04860013c0b7b6b6a7ec209 100644 --- a/secrets_manager/secrets_manager.go +++ b/secrets_manager/secrets_manager.go @@ -2,6 +2,7 @@ package secrets_manager import ( "encoding/base64" + "encoding/json" credentials2 "github.com/aws/aws-sdk-go/aws/credentials" "os" @@ -35,6 +36,8 @@ var ( secretManagerRegion = "af-south-1" ) +var secretManagerSession *secretsmanager.SecretsManager + func GetDatabaseCredentials(secretID string, isDebug bool) (DatabaseCredentials, error) { secret, _ := GetSecret(secretID, isDebug) var credentials DatabaseCredentials @@ -55,20 +58,25 @@ func GetS3UploadCredentials(secretID string, isDebug bool) (*credentials2.Creden return credentials2.NewStaticCredentials(credentials.AccessKeyID, credentials.SecretKey, ""), nil } -func GetSecret(secretID string, isDebug bool) (string, string) { - cachedSecret, err := secretCache.GetSecretString(secretID) - if err != nil { - logs.Info("Failed to get secret key from cache") - } - if cachedSecret != "" { - return cachedSecret, "" +// getSecretManagerSession Instantiates a new Secrets Manager client session +func getSecretManagerSession(isDebug bool) (err error) { + // If a session already exists, use it + if secretManagerSession != nil { + return nil } - awsSession := session.New() + logs.Info("Creating a new Secrets Manager session") + awsSession, err := session.NewSession() + if err != nil { + return err + } // Get local config if isDebug && os.Getenv("ENVIRONMENT") != "" { - logs.Info("Using access key %s", os.Getenv("AWS_ACCESS_KEY_ID")) + awsAccessKey := os.Getenv("AWS_ACCESS_KEY_ID") + if len(awsAccessKey) > 0 { + logs.Info("Using access key %s", awsAccessKey) + } awsSession, err = session.NewSessionWithOptions(session.Options{ Config: aws.Config{ Region: aws.String("af-south-1"), @@ -76,53 +84,53 @@ func GetSecret(secretID string, isDebug bool) (string, string) { }, }) if err != nil { - return "", "" + return err } } + // Create a Secrets Manager client session + secretManagerSession = secretsmanager.New(awsSession, aws.NewConfig().WithRegion(secretManagerRegion)) + + return nil +} + +// logError Logs any errors returned by the Secrets Manager client +func logError(err error) { + if aerr, ok := err.(awserr.Error); ok { + logs.Info(aerr.Code()+" %s", aerr.Error()) + } else { + // Print the error, cast err to awserr.Error to get the Code and + // Message from an error. + logs.Info(err.Error()) + } +} + +func GetSecret(secretID string, isDebug bool) (string, string) { + // Check if we have the secret in cache + cachedSecret, err := secretCache.GetSecretString(secretID) + if err != nil { + logs.Info("Failed to get secret key from cache") + } + if cachedSecret != "" { + return cachedSecret, "" + } + // Create a Secrets Manager client - svc := secretsmanager.New(awsSession, aws.NewConfig().WithRegion(secretManagerRegion)) + err = getSecretManagerSession(isDebug) + if err != nil { + logs.Info("Could not create client: %+v", err) + return "", "" + } + // Create a secret input := &secretsmanager.GetSecretValueInput{ SecretId: aws.String(string(secretID)), VersionStage: aws.String("AWSCURRENT"), // VersionStage defaults to AWSCURRENT if unspecified } - // In this sample we only handle the specific exceptions for the 'GetSecretValue' API. - // See https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html - - result, err := svc.GetSecretValue(input) + result, err := secretManagerSession.GetSecretValue(input) if err != nil { - if aerr, ok := err.(awserr.Error); ok { - switch aerr.Code() { - case secretsmanager.ErrCodeDecryptionFailure: - // Secrets Manager can't decrypt the protected secret text using the provided KMS key. - logs.Info(secretsmanager.ErrCodeDecryptionFailure, aerr.Error()) - - case secretsmanager.ErrCodeInternalServiceError: - // An error occurred on the server side. - logs.Info(secretsmanager.ErrCodeInternalServiceError, aerr.Error()) - - case secretsmanager.ErrCodeInvalidParameterException: - // You provided an invalid value for a parameter. - logs.Info(secretsmanager.ErrCodeInvalidParameterException, aerr.Error()) - - case secretsmanager.ErrCodeInvalidRequestException: - // You provided a parameter value that is not valid for the current state of the resource. - logs.Info(secretsmanager.ErrCodeInvalidRequestException, aerr.Error()) - - case secretsmanager.ErrCodeResourceNotFoundException: - // We can't find the resource that you asked for. - logs.Info("Can't find secret with ID: ", secretID) - logs.Info(secretsmanager.ErrCodeResourceNotFoundException, aerr.Error()) - default: - logs.Info(err.Error()) - } - } else { - // Print the error, cast err to awserr.Error to get the Code and - // Message from an error. - logs.Info(err.Error()) - } + logError(err) return "", "" } @@ -135,7 +143,7 @@ func GetSecret(secretID string, isDebug bool) (string, string) { decodedBinarySecretBytes := make([]byte, base64.StdEncoding.DecodedLen(len(result.SecretBinary))) length, err := base64.StdEncoding.Decode(decodedBinarySecretBytes, result.SecretBinary) if err != nil { - logs.Info("Base64 Decode Error:", err) + logs.Info("Base64 Decode Error: %+v", err) return "", "" } decodedBinarySecret = string(decodedBinarySecretBytes[:length]) @@ -143,3 +151,55 @@ func GetSecret(secretID string, isDebug bool) (string, string) { return secretString, decodedBinarySecret } + +// CreateSecret Creates a JSON marshaled "string secret" (can be expanded to cater for binary secrets should the need arise) +func CreateSecret(secretID string, secret any, isDebug bool) (awsSecretID string, err error) { + // Create a Secrets Manager client + err = getSecretManagerSession(isDebug) + if err != nil { + logs.Info("Could not create client: %+v", err) + return "", err + } + + // Create the secret - marshaling "any" into a JSON string + secretStr, err := json.Marshal(secret) + if err != nil { + logs.Info("Could not marshal secret: %+v", err) + return "", err + } + input := &secretsmanager.CreateSecretInput{ + Name: aws.String(secretID), + SecretString: aws.String(string(secretStr)), + } + + result, err := secretManagerSession.CreateSecret(input) + if err != nil { + logError(err) + return "", err + } + + return aws.StringValue(result.Name), nil +} + +func DeleteSecret(secretID string, forceWithoutRecovery bool, isDebug bool) error { + // Create a Secrets Manager client + err := getSecretManagerSession(isDebug) + if err != nil { + logs.Info("Could not create client: %+v", err) + return err + } + + // Delete the secret + input := &secretsmanager.DeleteSecretInput{ + SecretId: aws.String(secretID), + ForceDeleteWithoutRecovery: aws.Bool(forceWithoutRecovery), + } + + _, err = secretManagerSession.DeleteSecret(input) + if err != nil { + logError(err) + return err + } + + return nil +} diff --git a/secrets_manager/secrets_manager_test.go b/secrets_manager/secrets_manager_test.go new file mode 100644 index 0000000000000000000000000000000000000000..3e5d15d376b119b7824996a4eab7c2e538f09986 --- /dev/null +++ b/secrets_manager/secrets_manager_test.go @@ -0,0 +1,71 @@ +package secrets_manager + +import ( + "gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/date_utils" + "os" + "testing" + "time" +) + +var isDebug bool +var secretID = "TestSecret_" + time.Now().Format(date_utils.DateLayoutTrimmed()) + +func TestMain(m *testing.M) { + isDebug = true + os.Setenv("ENVIRONMENT", "dev") + os.Setenv("AWS_PROFILE", "") // <-- Use your AWS profile name here + + code := m.Run() + os.Exit(code) +} + +func TestAll(t *testing.T) { + testCreateSecret(t) + testGetSecret(t) + testDeleteSecret(t) +} + +func testCreateSecret(t *testing.T) { + type SubStruct struct { + Arg3a string + Arg3b string + } + type Anything struct { + Arg1 string + Arg2 string + Arg3 SubStruct + } + secret := Anything{ + Arg1: "lorem", + Arg2: "ipsum", + Arg3: SubStruct{ + Arg3a: "dolor", + Arg3b: "sit", + }, + } + + secretName, err := CreateSecret(secretID, secret, isDebug) + if err != nil { + t.Errorf("Secret with ID '%s' could not be created.", secretName) + } + + t.Logf("Secret with ID '%s' successfully created.", secretName) +} + +func testGetSecret(t *testing.T) { + secret, _ := GetSecret(secretID, isDebug) + if len(secret) <= 0 { + t.Errorf("Could not get secret with ID %s, or secret has no content", secretID) + } + + t.Logf("Secret with ID `%s` has content: %s", secretID, secret) +} + +func testDeleteSecret(t *testing.T) { + err := DeleteSecret(secretID, true, isDebug) + if err != nil { + t.Errorf("Secret with ID '%s' could not be deleted.", secretID) + return + } + t.Logf("Secret with ID '%s' successfully deleted.", secretID) +}