diff --git a/handler_utils/sqs.go b/handler_utils/sqs.go index 5a923583467111281331495c8d24d19b02eba10a..855bf1fcd1052f029402ec002c2cf12f96570288 100644 --- a/handler_utils/sqs.go +++ b/handler_utils/sqs.go @@ -5,6 +5,8 @@ import ( "github.com/aws/aws-lambda-go/events" "gitlab.com/uafrica/go-utils/errors" "gitlab.com/uafrica/go-utils/logs" + "gitlab.com/uafrica/go-utils/s3" + "gitlab.com/uafrica/go-utils/sqs" "reflect" ) @@ -30,11 +32,29 @@ func ValidateSQSEndpoints(endpoints map[string]interface{}) (map[string]interfac return endpoints, nil } -func GetRecord(message events.SQSMessage, recordType reflect.Type) (interface{}, error) { +func GetRecord(s3Session *s3.SessionWithHelpers, bucket string, message events.SQSMessage, recordType reflect.Type) (interface{}, error) { + recordValuePtr := reflect.New(recordType) - err := json.Unmarshal([]byte(message.Body), recordValuePtr.Interface()) - if err != nil { - return nil, errors.Wrapf(err, "failed to JSON decode message body") + + // Check if message body should be retrieved from S3 + if messageAttribute, ok := message.MessageAttributes[sqs.SQSMessageOnS3Key]; ok { + if messageAttribute.StringValue != nil && *messageAttribute.StringValue == "true" { + messageBytes, err := sqs.RetrieveMessageFromS3(s3Session, bucket, message.Body) + if err != nil { + return nil, errors.Wrapf(err, "failed to get sqs message body from s3") + } + + err = json.Unmarshal(messageBytes, recordValuePtr.Interface()) + if err != nil { + return nil, errors.Wrapf(err, "failed to JSON decode message body") + } + } + } else { + // Message was small enough, it is contained in the message body + err := json.Unmarshal([]byte(message.Body), recordValuePtr.Interface()) + if err != nil { + return nil, errors.Wrapf(err, "failed to JSON decode message body") + } } if validator, ok := recordValuePtr.Interface().(IValidator); ok { diff --git a/sqs/sqs.go b/sqs/sqs.go index 1260de16037fe5dfe0a2d5e4a65634b6a070c23f..cc71f3e862ebd2d2d5257345530048ff293fef90 100644 --- a/sqs/sqs.go +++ b/sqs/sqs.go @@ -5,7 +5,11 @@ package sqs import ( "encoding/json" "fmt" - "os" + "github.com/google/uuid" + "gitlab.com/uafrica/go-utils/s3" + "gitlab.com/uafrica/go-utils/string_utils" + "io/ioutil" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" @@ -13,18 +17,26 @@ import ( "gitlab.com/uafrica/go-utils/logs" ) +var sqsClient *sqs.SQS + +const SQSMessageOnS3Key = "message-on-s3" + // Messenger sends an arbitrary message via SQS type Messenger struct { - session *session.Session - service *sqs.SQS - queueURL string + QueueName string + QueueURL string + Region string + S3Session *s3.SessionWithHelpers + S3BucketName string + MessageGroupID *string + RequestIDHeaderKey string } -// NewSQSMessenger constructs a Messenger which sends messages to an SQS queue +// NewSQSClient constructs a Messenger which sends messages to an SQS queue // awsRegion - region that the queue was created // awsQueue - name of the queue // Note: Calling code needs SQS IAM permissions -func NewSQSMessenger(awsRegion, queueUrl string) (*Messenger, error) { +func NewSQSClient(awsRegion string) (*sqs.SQS, error) { // Make an AWS session sess, err := session.NewSessionWithOptions(session.Options{ Config: aws.Config{ @@ -37,24 +49,14 @@ func NewSQSMessenger(awsRegion, queueUrl string) (*Messenger, error) { } // Create SQS service - svc := sqs.New(sess) - - return &Messenger{ - session: sess, - service: svc, - queueURL: queueUrl, - }, nil + sqsClient = sqs.New(sess) + return sqsClient, err } // SendSQSMessage sends a message to the queue associated with the messenger // headers - string message attributes of the SQS message (see AWS SQS documentation) // body - body of the SQS message (see AWS SQS documentation) -func (m *Messenger) SendSQSMessage(headers map[string]string, body string, currentRequestID *string, sqsType string, headerKey string, messageGroupID ...string) (string, error) { - msgGrpID := "" - if len(messageGroupID) > 0 && messageGroupID[0] != "" { - msgGrpID = messageGroupID[0] - } - +func (m *Messenger) SendSQSMessage(headers map[string]string, body string, currentRequestID *string, sqsType string) (string, error) { msgAttrs := make(map[string]*sqs.MessageAttributeValue) for key, val := range headers { @@ -66,7 +68,7 @@ func (m *Messenger) SendSQSMessage(headers map[string]string, body string, curre // Add request ID if currentRequestID != nil { - msgAttrs[headerKey] = &sqs.MessageAttributeValue{ + msgAttrs[m.RequestIDHeaderKey] = &sqs.MessageAttributeValue{ DataType: aws.String("String"), StringValue: aws.String(*currentRequestID), } @@ -79,18 +81,18 @@ func (m *Messenger) SendSQSMessage(headers map[string]string, body string, curre var res *sqs.SendMessageOutput var err error - if msgGrpID == "" { - res, err = m.service.SendMessage(&sqs.SendMessageInput{ + if string_utils.UnwrapString(m.MessageGroupID) == "" { + res, err = sqsClient.SendMessage(&sqs.SendMessageInput{ MessageAttributes: msgAttrs, MessageBody: aws.String(body), - QueueUrl: &m.queueURL, + QueueUrl: &m.QueueURL, }) } else { - res, err = m.service.SendMessage(&sqs.SendMessageInput{ + res, err = sqsClient.SendMessage(&sqs.SendMessageInput{ MessageAttributes: msgAttrs, MessageBody: aws.String(body), - QueueUrl: &m.queueURL, - MessageGroupId: &msgGrpID, + QueueUrl: &m.QueueURL, + MessageGroupId: m.MessageGroupID, }) } @@ -101,17 +103,12 @@ func (m *Messenger) SendSQSMessage(headers map[string]string, body string, curre return *res.MessageId, err } -func SendSQSMessage(msgr *Messenger, region string, envQueueURLName string, objectToSend interface{}, currentRequestID *string, sqsType string, headerKey string, messageGroupID ...string) error { - msgGrpID := "" - if len(messageGroupID) > 0 && messageGroupID[0] != "" { - msgGrpID = messageGroupID[0] - } - - if msgr == nil { +func SendSQSMessage(msgr Messenger, objectToSend interface{}, currentRequestID *string, sqsType string) error { + if sqsClient == nil { var err error - msgr, err = NewSQSMessenger(region, os.Getenv(envQueueURLName)) + sqsClient, err = NewSQSClient(msgr.Region) if err != nil { - logs.ErrorWithMsg("Failed to create sqs messenger with envQueueURLName: "+envQueueURLName, err) + logs.ErrorWithMsg("Failed to create sqs client", err) } } @@ -121,13 +118,64 @@ func SendSQSMessage(msgr *Messenger, region string, envQueueURLName string, obje return err } - headers := map[string]string{"Name": "dummy"} - msgID, err := msgr.SendSQSMessage(headers, string(jsonBytes), currentRequestID, sqsType, headerKey, msgGrpID) + message := string(jsonBytes) + headers := map[string]string{ + "Name": "dummy", + SQSMessageOnS3Key: "false", + } + + // If bigger than 200 KB upload message to s3 and send s3 file name to sqs + // The sqs receiver should check the header to see if the message is in the body or on s3 + if len(jsonBytes) > 0 { + headers[SQSMessageOnS3Key] = "true" + id := uuid.New() + filename := fmt.Sprintf("%v-%v", sqsType, id.String()) + + err := uploadMessageToS3(msgr.S3Session, msgr.S3BucketName, filename, jsonBytes) + if err != nil { + return err + } + + message = filename // Send filename as the message + } + + msgID, err := msgr.SendSQSMessage(headers, message, currentRequestID, sqsType) if err != nil { logs.ErrorWithMsg("Failed to send sqs event", err) return err } - logs.Info(fmt.Sprintf("Sent SQS message to %s with ID %s", envQueueURLName, msgID)) + logs.Info(fmt.Sprintf("Sent SQS message to %s with ID %s", msgr.QueueName, msgID)) return nil } + +func uploadMessageToS3(session *s3.SessionWithHelpers, bucket string, name string, messageBytes []byte) error { + // Upload message + expiry := 24 * 7 * time.Hour // 3 days + _, err := session.UploadWithSettings(messageBytes, bucket, name, s3.S3UploadSettings{ + ExpiryDuration: &expiry, + }) + if err != nil { + return err + } + + return nil +} + +func RetrieveMessageFromS3(session *s3.SessionWithHelpers, bucket string, filename string) ([]byte, error) { + // get the file contents + rawObject, err := session.GetObject(bucket, filename, false) + if err != nil { + return []byte{}, err + } + + // Read the message + var bodyBytes []byte + bodyBytes, err = ioutil.ReadAll(rawObject.Body) + if err != nil { + logs.ErrorWithMsg("Could not read file", err) + return []byte{}, err + } + + return bodyBytes, nil +}