diff --git a/handler_utils/api.go b/handler_utils/api.go index 0c14596501b845e90a2facd030f94ac596b7e2ac..41d45ca5ce061bd90837292450c2935d5fac1cca 100644 --- a/handler_utils/api.go +++ b/handler_utils/api.go @@ -43,6 +43,28 @@ func ValidateAPIEndpoints(endpoints map[string]map[string]interface{}) (map[stri return endpoints, nil } +// ValidateWebsocketEndpoints checks that all websocket endpoints are correctly defined using one of the supported +// handler types and returns updated endpoints with additional information +func ValidateWebsocketEndpoints(endpoints map[string]interface{}) (map[string]interface{}, error) { + for websocketAction, actionFunc := range endpoints { + if websocketAction == "" { + return nil, errors.Errorf("blank action") + } + if actionFunc == nil { + return nil, errors.Errorf("nil handler on %s %s", websocketAction, actionFunc) + } + + handler, err := NewHandler(actionFunc) + if err != nil { + return nil, errors.Wrapf(err, "%s has invalid handler %T", websocketAction, actionFunc) + } + // replace the endpoint value so that we can quickly call this handler + endpoints[websocketAction] = handler + + } + return endpoints, nil +} + func ValidateRequestParams(request *events.APIGatewayProxyRequest, paramsStructType reflect.Type) (reflect.Value, error) { paramValues := map[string]interface{}{} for n, v := range request.QueryStringParameters { diff --git a/logs/logs.go b/logs/logs.go index 6fc4ce57eb9c27ee5a5c22626a625f64b9a1c431..c901a85ca3a37dd569e0e53f385c8e166d5985b9 100644 --- a/logs/logs.go +++ b/logs/logs.go @@ -369,6 +369,24 @@ func LogResponseInfo(req events.APIGatewayProxyRequest, res events.APIGatewayPro InfoWithFields(fields, "Res") } +func LogFullResponseInfo(res events.APIGatewayProxyResponse, err error) { + if disableLogging { + return + } + + fields := map[string]interface{}{ + "status_code": res.StatusCode, + } + + if err != nil { + fields["error"] = err + } + + fields["body"] = res.Body + + InfoWithFields(fields, "Res") +} + func LogApiAudit(fields log.Fields) { if disableLogging { return @@ -409,6 +427,23 @@ func LogSQSEvent(event events.SQSEvent) { }, "") } +func LogWebsocketEvent(req events.APIGatewayWebsocketProxyRequest, shouldExcludeBody bool) { + if disableLogging { + return + } + + fields := map[string]interface{}{ + "route": req.RequestContext.RouteKey, + "connection_id": req.RequestContext.ConnectionID, + } + + if !shouldExcludeBody { + fields["body"] = req.Body + } + + InfoWithFields(fields, "Req") +} + func SetOutput(out io.Writer) { log.SetOutput(out) } diff --git a/websocket_utils/websocket_utils.go b/websocket_utils/websocket_utils.go new file mode 100644 index 0000000000000000000000000000000000000000..c6a62245e8931fdff591fbdd795602ad5bfd2954 --- /dev/null +++ b/websocket_utils/websocket_utils.go @@ -0,0 +1,73 @@ +package websocket_utils + +import ( + "fmt" + "github.com/aws/aws-lambda-go/events" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/apigatewaymanagementapi" + "gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/errors" + "gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/utils" + "os" +) + +var ( + sessions = map[string]*APIGateWaySessionWithHelpers{} +) + +type APIGateWaySessionWithHelpers struct { + APIGatewaySession *apigatewaymanagementapi.ApiGatewayManagementApi +} + +func GetSession(region ...string) *APIGateWaySessionWithHelpers { + s3Region := os.Getenv("AWS_REGION") + + // Set custom region + if region != nil && len(region) > 0 { + s3Region = region[0] + } + + // Check if session exists for region, if it does return it + if apiGatewaySession, ok := sessions[s3Region]; ok { + return apiGatewaySession + } + + // Setup session + options := session.Options{ + Config: aws.Config{ + Region: utils.ValueToPointer(s3Region), + }, + } + sess, err := session.NewSessionWithOptions(options) + if err != nil { + return nil + } + apiGatewaySession := NewSession(sess) + sessions[s3Region] = apiGatewaySession + return apiGatewaySession +} + +func NewSession(session *session.Session) *APIGateWaySessionWithHelpers { + return &APIGateWaySessionWithHelpers{ + APIGatewaySession: apigatewaymanagementapi.New(session), + } +} + +func (s APIGateWaySessionWithHelpers) PostToConnectionIDs(req *events.APIGatewayWebsocketProxyRequest, connectionIDs []string) error { + if req == nil { + return errors.Error("websocket request is nil") + } + + for _, connectionID := range connectionIDs { + s.APIGatewaySession.Endpoint = fmt.Sprintf("https://%s/%s", req.RequestContext.DomainName, req.RequestContext.Stage) + _, err := s.APIGatewaySession.PostToConnection(&apigatewaymanagementapi.PostToConnectionInput{ + ConnectionId: &connectionID, + Data: []byte("This is a message sent to the requesting connection ID via websocket."), + }) + if err != nil { + return err + } + } + + return nil +}