diff --git a/websocket_utils/websocket_utils.go b/websocket_utils/websocket_utils.go index 7c508a79a3e1f7ce47ff0e8807678cd70dd1428f..954c010f9005678535d3e1e91e2acefd1fd5aa59 100644 --- a/websocket_utils/websocket_utils.go +++ b/websocket_utils/websocket_utils.go @@ -3,7 +3,6 @@ package websocket_utils import ( "context" "fmt" - "github.com/aws/aws-lambda-go/events" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/apigatewaymanagementapi" @@ -18,14 +17,17 @@ var ( type APIGateWayClientWithHelpers struct { APIGatewayClient *apigatewaymanagementapi.Client + Available bool } -func GetClient(req *events.APIGatewayWebsocketProxyRequest, region ...string) *APIGateWayClientWithHelpers { - if req == nil { - logs.ErrorMsg("APIGatewayWebsocketProxyRequest is nil") - return nil +func GetClient(region ...string) *APIGateWayClientWithHelpers { + domainName := os.Getenv("WEBSOCKET_DOMAIN_NAME") + if domainName == "" { + logs.ErrorMsg("WEBSOCKET_API_DOMAIN_NAME env variable is not set") + return &APIGateWayClientWithHelpers{} } + env := utils.GetEnv("ENVIRONMENT", "dev") s3Region := os.Getenv("AWS_REGION") // Set custom region @@ -43,33 +45,43 @@ func GetClient(req *events.APIGatewayWebsocketProxyRequest, region ...string) *A config.WithRegion(s3Region), ) if err != nil { - return nil + logs.ErrorMsg("failed to load AWS config for websocket client") + return &APIGateWayClientWithHelpers{} } - apiGatewaySession := NewClient(req, cfg) + apiGatewaySession := NewClient(cfg, domainName, env) sessions[s3Region] = apiGatewaySession return apiGatewaySession } -func NewClient(req *events.APIGatewayWebsocketProxyRequest, config aws.Config) *APIGateWayClientWithHelpers { +func NewClient(config aws.Config, domainName string, env string) *APIGateWayClientWithHelpers { return &APIGateWayClientWithHelpers{ APIGatewayClient: apigatewaymanagementapi.NewFromConfig(config, func(o *apigatewaymanagementapi.Options) { - o.BaseEndpoint = utils.ValueToPointer(fmt.Sprintf("https://%s/%s", req.RequestContext.DomainName, req.RequestContext.Stage)) + o.BaseEndpoint = utils.ValueToPointer(fmt.Sprintf("https://%s/%s", domainName, env)) }), + Available: true, } } -func (s APIGateWayClientWithHelpers) PostToConnectionID(data []byte, connectionID string) error { - _, err := s.APIGatewayClient.PostToConnection(context.TODO(), &apigatewaymanagementapi.PostToConnectionInput{ +func (a APIGateWayClientWithHelpers) PostToConnectionID(data []byte, connectionID string) error { + if !a.Available { + return nil + } + + _, err := a.APIGatewayClient.PostToConnection(context.TODO(), &apigatewaymanagementapi.PostToConnectionInput{ ConnectionId: &connectionID, Data: data, }) return err } -func (s APIGateWayClientWithHelpers) PostToConnectionIDs(data []byte, connectionIDs []string) { +func (a APIGateWayClientWithHelpers) PostToConnectionIDs(data []byte, connectionIDs []string) { + if !a.Available { + return + } + for _, connectionID := range connectionIDs { - _, err := s.APIGatewayClient.PostToConnection(context.TODO(), &apigatewaymanagementapi.PostToConnectionInput{ + _, err := a.APIGatewayClient.PostToConnection(context.TODO(), &apigatewaymanagementapi.PostToConnectionInput{ ConnectionId: &connectionID, Data: data, })