diff --git a/ip_utils/ip_utils.go b/ip_utils/ip_utils.go index a72f734061df41eab7b4ee2e81d369b664ec9892..a1ee2dd4145075e4156a29fee56b89fe06c139e7 100644 --- a/ip_utils/ip_utils.go +++ b/ip_utils/ip_utils.go @@ -2,7 +2,9 @@ package ip_utils import ( "fmt" + "github.com/aws/aws-lambda-go/events" "gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/errors" + "gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/handler_utils" "net" "os" "strings" @@ -163,4 +165,27 @@ func ValidateIPAddress(ipAddress string) (cleanedIPAddress string, err error) { return ipAddress, nil } +func GetRequestSourceIP(proxyRequest *events.APIGatewayProxyRequest, websocketReqeuest *events.APIGatewayWebsocketProxyRequest) string { + var requestSourceIP string + if proxyRequest != nil { + requestSourceIP = proxyRequest.RequestContext.Identity.SourceIP + // Cloudflare uses this header to pass the real IP + forwardedForHeader := handler_utils.FindHeaderValue(proxyRequest.Headers, "x-forwarded-for") + if forwardedForHeader != "" && + VerifyCloudflareSourceIP(requestSourceIP) { + forwardedForHeaderIPs := strings.Split(forwardedForHeader, ",") + + if len(forwardedForHeaderIPs) > 0 { + // Use the first IP as the source IP + headerSourceIP := strings.TrimSpace(forwardedForHeaderIPs[len(forwardedForHeaderIPs)-1]) + return headerSourceIP + } + } + } else if websocketReqeuest != nil { + requestSourceIP = websocketReqeuest.RequestContext.Identity.SourceIP + } + + return requestSourceIP +} + // endregion Helpers