From 02882f755c0e8b5f250482d091068b020e153903 Mon Sep 17 00:00:00 2001
From: jano3 <jano@bob.co.za>
Date: Thu, 23 May 2024 14:17:30 +0200
Subject: [PATCH] #40 Websocket utils

---
 handler_utils/api.go               | 22 +++++++++
 logs/logs.go                       | 35 ++++++++++++++
 websocket_utils/websocket_utils.go | 73 ++++++++++++++++++++++++++++++
 3 files changed, 130 insertions(+)
 create mode 100644 websocket_utils/websocket_utils.go

diff --git a/handler_utils/api.go b/handler_utils/api.go
index 0c14596..41d45ca 100644
--- a/handler_utils/api.go
+++ b/handler_utils/api.go
@@ -43,6 +43,28 @@ func ValidateAPIEndpoints(endpoints map[string]map[string]interface{}) (map[stri
 	return endpoints, nil
 }
 
+// ValidateWebsocketEndpoints checks that all websocket endpoints are correctly defined using one of the supported
+// handler types and returns updated endpoints with additional information
+func ValidateWebsocketEndpoints(endpoints map[string]interface{}) (map[string]interface{}, error) {
+	for websocketAction, actionFunc := range endpoints {
+		if websocketAction == "" {
+			return nil, errors.Errorf("blank action")
+		}
+		if actionFunc == nil {
+			return nil, errors.Errorf("nil handler on %s %s", websocketAction, actionFunc)
+		}
+
+		handler, err := NewHandler(actionFunc)
+		if err != nil {
+			return nil, errors.Wrapf(err, "%s has invalid handler %T", websocketAction, actionFunc)
+		}
+		// replace the endpoint value so that we can quickly call this handler
+		endpoints[websocketAction] = handler
+
+	}
+	return endpoints, nil
+}
+
 func ValidateRequestParams(request *events.APIGatewayProxyRequest, paramsStructType reflect.Type) (reflect.Value, error) {
 	paramValues := map[string]interface{}{}
 	for n, v := range request.QueryStringParameters {
diff --git a/logs/logs.go b/logs/logs.go
index 6fc4ce5..c901a85 100644
--- a/logs/logs.go
+++ b/logs/logs.go
@@ -369,6 +369,24 @@ func LogResponseInfo(req events.APIGatewayProxyRequest, res events.APIGatewayPro
 	InfoWithFields(fields, "Res")
 }
 
+func LogFullResponseInfo(res events.APIGatewayProxyResponse, err error) {
+	if disableLogging {
+		return
+	}
+
+	fields := map[string]interface{}{
+		"status_code": res.StatusCode,
+	}
+
+	if err != nil {
+		fields["error"] = err
+	}
+
+	fields["body"] = res.Body
+
+	InfoWithFields(fields, "Res")
+}
+
 func LogApiAudit(fields log.Fields) {
 	if disableLogging {
 		return
@@ -409,6 +427,23 @@ func LogSQSEvent(event events.SQSEvent) {
 	}, "")
 }
 
+func LogWebsocketEvent(req events.APIGatewayWebsocketProxyRequest, shouldExcludeBody bool) {
+	if disableLogging {
+		return
+	}
+
+	fields := map[string]interface{}{
+		"route":         req.RequestContext.RouteKey,
+		"connection_id": req.RequestContext.ConnectionID,
+	}
+
+	if !shouldExcludeBody {
+		fields["body"] = req.Body
+	}
+
+	InfoWithFields(fields, "Req")
+}
+
 func SetOutput(out io.Writer) {
 	log.SetOutput(out)
 }
diff --git a/websocket_utils/websocket_utils.go b/websocket_utils/websocket_utils.go
new file mode 100644
index 0000000..c6a6224
--- /dev/null
+++ b/websocket_utils/websocket_utils.go
@@ -0,0 +1,73 @@
+package websocket_utils
+
+import (
+	"fmt"
+	"github.com/aws/aws-lambda-go/events"
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/aws/aws-sdk-go/service/apigatewaymanagementapi"
+	"gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/errors"
+	"gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/utils"
+	"os"
+)
+
+var (
+	sessions = map[string]*APIGateWaySessionWithHelpers{}
+)
+
+type APIGateWaySessionWithHelpers struct {
+	APIGatewaySession *apigatewaymanagementapi.ApiGatewayManagementApi
+}
+
+func GetSession(region ...string) *APIGateWaySessionWithHelpers {
+	s3Region := os.Getenv("AWS_REGION")
+
+	// Set custom region
+	if region != nil && len(region) > 0 {
+		s3Region = region[0]
+	}
+
+	// Check if session exists for region, if it does return it
+	if apiGatewaySession, ok := sessions[s3Region]; ok {
+		return apiGatewaySession
+	}
+
+	// Setup session
+	options := session.Options{
+		Config: aws.Config{
+			Region: utils.ValueToPointer(s3Region),
+		},
+	}
+	sess, err := session.NewSessionWithOptions(options)
+	if err != nil {
+		return nil
+	}
+	apiGatewaySession := NewSession(sess)
+	sessions[s3Region] = apiGatewaySession
+	return apiGatewaySession
+}
+
+func NewSession(session *session.Session) *APIGateWaySessionWithHelpers {
+	return &APIGateWaySessionWithHelpers{
+		APIGatewaySession: apigatewaymanagementapi.New(session),
+	}
+}
+
+func (s APIGateWaySessionWithHelpers) PostToConnectionIDs(req *events.APIGatewayWebsocketProxyRequest, connectionIDs []string) error {
+	if req == nil {
+		return errors.Error("websocket request is nil")
+	}
+
+	for _, connectionID := range connectionIDs {
+		s.APIGatewaySession.Endpoint = fmt.Sprintf("https://%s/%s", req.RequestContext.DomainName, req.RequestContext.Stage)
+		_, err := s.APIGatewaySession.PostToConnection(&apigatewaymanagementapi.PostToConnectionInput{
+			ConnectionId: &connectionID,
+			Data:         []byte("This is a message sent to the requesting connection ID via websocket."),
+		})
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
-- 
GitLab