From 8a2ecaefea7c276cc62fed11bb37f6efb38779af Mon Sep 17 00:00:00 2001
From: jano3 <jano@bob.co.za>
Date: Wed, 24 Jan 2024 14:25:12 +0200
Subject: [PATCH] Add options to disable logs or output logs to buffer

---
 logs/logs.go | 126 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 126 insertions(+)

diff --git a/logs/logs.go b/logs/logs.go
index 0cb605d..0106d6f 100644
--- a/logs/logs.go
+++ b/logs/logs.go
@@ -7,6 +7,7 @@ import (
 	"gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/errors"
 	"gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/string_utils"
 	"gitlab.bob.co.za/bob-public-utils/bobgroup-go-utils/utils"
+	"io"
 	"net/http"
 	"net/url"
 	"os"
@@ -21,11 +22,19 @@ import (
 	log "github.com/sirupsen/logrus"
 )
 
+type LogBufferWithLevel struct {
+	LogBuffer *bytes.Buffer
+	LogLevel  log.Level
+}
+
 var logger *log.Entry
+var logBuffers []*LogBufferWithLevel
 
 var apiRequest *events.APIGatewayProxyRequest
 var currentRequestID *string
 var isDebug = false
+var logToBuffer = false
+var disableLogging = false
 var build string
 var raygunClient *raygun4go.Client
 
@@ -165,25 +174,50 @@ func getLogger() *log.Entry {
 }
 
 func InfoWithFields(fields map[string]interface{}, message interface{}) {
+	if disableLogging {
+		return
+	}
+
 	if reflect.TypeOf(message).Kind() == reflect.String {
 		message = SanitiseLogs(message.(string))
 	}
 	sanitisedFields := SanitiseFields(fields)
+
+	logBuffer := CheckToGetLogBuffer()
 	getLogger().WithFields(sanitisedFields).Info(message)
+	CheckToStoreLogBuffer(logBuffer, log.InfoLevel)
 }
 
 func Info(format string, a ...interface{}) {
+	if disableLogging {
+		return
+	}
+
 	message := SanitiseLogs(fmt.Sprintf(format, a...))
+
+	logBuffer := CheckToGetLogBuffer()
 	getLogger().Info(message)
+	CheckToStoreLogBuffer(logBuffer, log.InfoLevel)
 }
 
 func ErrorWithFields(fields map[string]interface{}, err error) {
+	if disableLogging {
+		return
+	}
+
 	sanitisedFields := SanitiseFields(fields)
 	sendRaygunError(sanitisedFields, err)
+
+	logBuffer := CheckToGetLogBuffer()
 	getLogger().WithFields(sanitisedFields).Error(err)
+	CheckToStoreLogBuffer(logBuffer, log.ErrorLevel)
 }
 
 func ErrorWithMsg(message string, err error) {
+	if disableLogging {
+		return
+	}
+
 	if err == nil {
 		err = errors.Error(message)
 	}
@@ -193,32 +227,63 @@ func ErrorWithMsg(message string, err error) {
 }
 
 func ErrorMsg(message string) {
+	if disableLogging {
+		return
+	}
+
 	ErrorWithMsg(message, nil)
 }
 
 func Warn(format string, a ...interface{}) {
+	if disableLogging {
+		return
+	}
+
 	message := SanitiseLogs(fmt.Sprintf(format, a...))
+
+	logBuffer := CheckToGetLogBuffer()
 	getLogger().Warn(message)
+	CheckToStoreLogBuffer(logBuffer, log.WarnLevel)
 }
 
 func WarnWithFields(fields map[string]interface{}, err error) {
+	if disableLogging {
+		return
+	}
+
 	sanitisedFields := SanitiseFields(fields)
+	logBuffer := CheckToGetLogBuffer()
 	getLogger().WithFields(sanitisedFields).Warn(err)
+	CheckToStoreLogBuffer(logBuffer, log.WarnLevel)
 }
 
 func SQLDebugInfo(sql string) {
+	if disableLogging {
+		return
+	}
+
+	logBuffer := CheckToGetLogBuffer()
 	getLogger().WithFields(map[string]interface{}{
 		"sql": sql,
 	}).Debug("SQL query")
+	CheckToStoreLogBuffer(logBuffer, log.InfoLevel)
 }
 
 func LogShipmentID(id int64) {
+	if disableLogging {
+		return
+	}
+
 	InfoWithFields(map[string]interface{}{
 		"shipment_id": id,
 	}, "Current-shipment-ID")
 }
 
 func LogRequestInfo(req events.APIGatewayProxyRequest, shouldExcludeBody bool, extraFields map[string]interface{}) {
+	if disableLogging {
+		return
+	}
+
 	fields := map[string]interface{}{
 		"path":   req.Path,
 		"method": req.HTTPMethod,
@@ -246,6 +311,10 @@ func LogRequestInfo(req events.APIGatewayProxyRequest, shouldExcludeBody bool, e
 }
 
 func LogResponseInfo(req events.APIGatewayProxyRequest, res events.APIGatewayProxyResponse, err error) {
+	if disableLogging {
+		return
+	}
+
 	fields := map[string]interface{}{
 		"status_code": res.StatusCode,
 	}
@@ -262,10 +331,20 @@ func LogResponseInfo(req events.APIGatewayProxyRequest, res events.APIGatewayPro
 }
 
 func LogApiAudit(fields log.Fields) {
+	if disableLogging {
+		return
+	}
+
+	logBuffer := CheckToGetLogBuffer()
 	getLogger().WithFields(fields).Info("api-audit-log")
+	CheckToStoreLogBuffer(logBuffer, log.InfoLevel)
 }
 
 func LogSQSEvent(event events.SQSEvent) {
+	if disableLogging {
+		return
+	}
+
 	sqsReducedEvents := []map[string]string{}
 
 	for _, record := range event.Records {
@@ -288,10 +367,22 @@ func LogSQSEvent(event events.SQSEvent) {
 	}, "")
 }
 
+func SetOutput(out io.Writer) {
+	log.SetOutput(out)
+}
+
 func SetOutputToFile(file *os.File) {
 	log.SetOutput(file)
 }
 
+func SetOutputToBuffer(outputToLogBuffer bool) {
+	logToBuffer = outputToLogBuffer
+}
+
+func DisableLogging() {
+	disableLogging = true
+}
+
 func ClearInfo() {
 	logger = nil
 }
@@ -402,3 +493,38 @@ func (f *CustomLogFormatter) Format(entry *log.Entry) ([]byte, error) {
 
 	return b.Bytes(), nil
 }
+
+func CheckToGetLogBuffer() *bytes.Buffer {
+	if logToBuffer {
+		logBuffer := &bytes.Buffer{}
+		log.SetOutput(logBuffer)
+		return logBuffer
+	}
+	return nil
+}
+
+func CheckToStoreLogBuffer(logBuffer *bytes.Buffer, logLevel log.Level) {
+	if logBuffer != nil {
+		logBuffers = append(logBuffers, &LogBufferWithLevel{
+			LogBuffer: logBuffer,
+			LogLevel:  logLevel,
+		})
+	}
+}
+
+func LogAllLogBuffers() {
+	log.SetOutput(os.Stderr)
+	for _, logBuffer := range logBuffers {
+		switch logBuffer.LogLevel {
+		case log.InfoLevel:
+			getLogger().Info(logBuffer.LogBuffer.String())
+		case log.ErrorLevel:
+			getLogger().Error(logBuffer.LogBuffer.String())
+		case log.WarnLevel:
+			getLogger().Warn(logBuffer.LogBuffer.String())
+		case log.DebugLevel:
+			getLogger().Debug(logBuffer.LogBuffer.String())
+
+		}
+	}
+}
-- 
GitLab