Skip to content
Snippets Groups Projects
Select Git revision
  • 5a132b451fcb662227b36b51d0d417974ac60490
  • main default protected
  • v1.298.0
  • v1.297.0
  • v1.296.0
  • v1.295.0
  • v1.294.0
  • v1.293.0
  • v1.292.0
  • v1.291.0
  • v1.290.0
  • v1.289.0
  • v1.288.0
  • v1.287.0
  • v1.286.0
  • v1.285.0
  • v1.284.0
  • v1.283.0
  • v1.282.0
  • v1.281.0
  • v1.280.0
  • v1.279.0
22 results

audit.go

Blame
  • api.go 5.73 KiB
    package api
    
    import (
    	"fmt"
    	"net/http"
    	"runtime/debug"
    	"sync"
    
    	"github.com/aws/aws-lambda-go/events"
    	"github.com/aws/aws-lambda-go/lambda"
    	"gitlab.com/uafrica/go-utils/audit"
    	"gitlab.com/uafrica/go-utils/errors"
    	"gitlab.com/uafrica/go-utils/logger"
    	queues_mem "gitlab.com/uafrica/go-utils/queues/mem"
    	queues_sqs "gitlab.com/uafrica/go-utils/queues/sqs"
    	"gitlab.com/uafrica/go-utils/service"
    	"gitlab.com/uafrica/go-utils/string_utils"
    )
    
    //CurrentRequestID is a package-level pointer to a request ID for legacy handlers.
    //
    //LEGACY: this global variable is set only for backward compatibility.
    //When handlers are changed to accept a context, they should get the
    //request ID from the context instead of reading this variable.
    //
    //NOTE(review): it is not assigned in this part of the file. If it is set
    //per request elsewhere, one package-level pointer is shared by all
    //concurrently handled requests (e.g. when Run serves locally through
    //http.ListenAndServe). Confirm before relying on it.
    var CurrentRequestID *string
    
    //New builds an Api that serves routes, keyed as routes[path][method].
    //Each route value may be any handler function signature accepted by api.Router.
    //requestIDHeaderKey names the response header that echoes the request's ID;
    //when it is empty, "request-id" is used.
    //New panics if NewRouter rejects routes.
    func New(requestIDHeaderKey string, routes map[string]map[string]interface{}) Api {
    	router, err := NewRouter(routes)
    	if err != nil {
    		panic(fmt.Sprintf("cannot create router: %+v", err))
    	}

    	headerKey := requestIDHeaderKey
    	if len(headerKey) == 0 {
    		headerKey = "request-id"
    	}

    	//cors, localPort and localQueueEventHandlers are left at their zero values:
    	//cors unset, localPort 0 (Run then uses lambda), no local in-memory queue handlers
    	return Api{
    		Service:            service.New(),
    		router:             router,
    		requestIDHeaderKey: headerKey,
    		checks:             make(map[string]ICheck),
    		crashReporter:      defaultCrashReporter{},
    	}
    }
    
    //Api is the service API built by New and configured by chaining the With...
    //methods. Those methods take and return Api by value, so the caller must use
    //the returned Api. Run starts it either under lambda or as a local HTTP server.
    type Api struct {
    	//embedded service (starters, error reporter, auditor, producer, logging)
    	service.Service
    	//router built by NewRouter from the routes given to New
    	router                  Router
    	//response header key for the request ID ("request-id" when New gets "")
    	requestIDHeaderKey      string
    	//named checks registered with WithCheck.
    	//NOTE(review): a map is a reference, so copies of Api share it unless it is copied on write
    	checks                  map[string]ICheck
    	//crash reporter; New installs defaultCrashReporter, replaceable via WithCrashReported
    	crashReporter           ICrashReporter
    	//CORS configuration set via WithCORS; nil by default
    	cors                    ICORS
    	localPort               int                    //==0 for default lambda, >0 for http.ListenAndServe to run locally
    	localQueueEventHandlers map[string]interface{} //only applies when running locally for local in-memory queues
    }
    
    //WithStarter returns a copy of api whose embedded Service is replaced by the
    //result of Service.WithStarter(name, starter). It wraps the Service method so
    //the result is still an Api and calls can keep chaining.
    func (api Api) WithStarter(name string, starter service.IStarter) Api {
    	updated := api
    	updated.Service = updated.Service.WithStarter(name, starter)
    	return updated
    }
    
    //WithErrorReporter returns a copy of api whose embedded Service is replaced by
    //the result of Service.WithErrorReporter(reporter). It wraps the Service method
    //so the result is still an Api and calls can keep chaining.
    func (api Api) WithErrorReporter(reporter service.IErrorReporter) Api {
    	updated := api
    	updated.Service = updated.Service.WithErrorReporter(reporter)
    	return updated
    }
    
    //WithAuditor returns a copy of api whose embedded Service is replaced by the
    //result of Service.WithAuditor(auditor). It wraps the Service method so the
    //result is still an Api and calls can keep chaining.
    func (api Api) WithAuditor(auditor audit.Auditor) Api {
    	updated := api
    	updated.Service = updated.Service.WithAuditor(auditor)
    	return updated
    }
    
    //WithProducer returns a copy of api whose embedded Service is replaced by the
    //result of Service.WithProducer(producer). It wraps the Service method so the
    //result is still an Api and calls can keep chaining (Run uses it to pick the
    //SQS or in-memory producer).
    func (api Api) WithProducer(producer service.Producer) Api {
    	updated := api
    	updated.Service = updated.Service.WithProducer(producer)
    	return updated
    }
    
    //WithCheck returns a copy of api with an extra named check that is run at
    //the start of each request context.
    //
    //If a check returns an error, processing stops and that error is returned.
    //If a check succeeds and returns non-nil data, the data is stored against
    //name, so your handler can retrieve it with:
    //		checkData := ctx.Value(name).(expectedType)
    //	or
    //		checkData, ok := ctx.Value(name).(expectedType)
    //		if !ok { ... }
    //You can implement one check that does everything and returns a struct, or
    //one for your db, one for rate limit, one for auth, one for ...
    //
    //name must be snake_case, e.g. "this_is_my_check_data_name".
    //WithCheck panics if name is not snake_case, if check is nil, or if a check
    //with the same name was already added.
    //
    //The checks map is copied before the new entry is added. The receiver, and
    //any other Api built from the same base, is therefore not modified, and a
    //zero Api (nil checks map) is handled.
    //
    //NOTE(review): checks are stored in a map, which does not keep insertion
    //order. Whether they really run "in the sequence they were added" depends on
    //how the handler iterates them. Confirm.
    func (api Api) WithCheck(name string, check ICheck) Api {
    	if !string_utils.IsSnakeCase(name) {
    		panic(errors.Errorf("invalid check name=\"%s\", expecting snake_case names only", name))
    	}
    	if check == nil {
    		panic(errors.Errorf("check(%s) func==nil", name))
    	}
    	if _, ok := api.checks[name]; ok {
    		panic(errors.Errorf("check(%s) already defined", name))
    	}

    	//copy-on-write: Api is passed by value but a map is a reference, so writing
    	//into api.checks directly would also change the caller's (shared) map
    	checks := make(map[string]ICheck, len(api.checks)+1)
    	for existingName, existingCheck := range api.checks {
    		checks[existingName] = existingCheck
    	}
    	checks[name] = check
    	api.checks = checks
    	return api
    }
    
    //WithCORS returns a copy of api that uses cors as its CORS configuration.
    //A nil cors is ignored and the current setting is kept. A typed nil pointer
    //wrapped in ICORS is not == nil and would be stored.
    func (api Api) WithCORS(cors ICORS) Api {
    	if cors == nil {
    		return api
    	}
    	api.cors = cors
    	return api
    }
    
    //WithCrashReported returns a copy of api that uses crashReporter in place of
    //the default (defaultCrashReporter, installed by New). A nil crashReporter is
    //ignored and the current one is kept.
    //NOTE(review): the name looks like a misspelling of "WithCrashReporter". It is
    //left unchanged so existing callers keep compiling.
    func (api Api) WithCrashReported(crashReporter ICrashReporter) Api {
    	if crashReporter == nil {
    		return api
    	}
    	api.crashReporter = crashReporter
    	return api
    }
    
    //WithLocalPort returns a copy of api configured to run as a local HTTP server.
    //When localPortPtr is non-nil and points to a value > 0, that port and
    //eventHandlers (used for local in-memory queues) are stored. Otherwise api is
    //returned unchanged, so Run keeps using lambda.
    //It panics if a local port has already been set.
    func (api Api) WithLocalPort(localPortPtr *int, eventHandlers map[string]interface{}) Api {
    	if api.localPort != 0 {
    		panic("local port already defined")
    	}
    	if localPortPtr == nil || *localPortPtr <= 0 {
    		return api
    	}
    	api.localPort, api.localQueueEventHandlers = *localPortPtr, eventHandlers
    	return api
    }
    
    //Run starts the API and panics on error.
    //
    //Lambda mode (localPort <= 0): installs the SQS producer and passes
    //api.Handler to lambda.Start.
    //
    //Local mode (localPort > 0): serves HTTP on that port with
    //http.ListenAndServe. The producer is in-memory (queues_mem) when
    //localQueueEventHandlers were given, and SQS otherwise.
    func (api Api) Run() {
    	if api.localPort <= 0 {
    		api = api.WithProducer(queues_sqs.NewProducer(api.requestIDHeaderKey))
    		lambda.Start(api.Handler) //calls api.Handler directly
    		return
    	}

    	//running locally with standard HTTP server
    	if api.localQueueEventHandlers == nil {
    		//no local handlers supplied: events still go to SQS
    		api = api.WithProducer(queues_sqs.NewProducer(api.requestIDHeaderKey))
    	} else {
    		//locally, events are produced into an in-memory consumer built from
    		//localQueueEventHandlers instead of SQS.
    		//this is a quick hack... to become part of the framework once it works well
    		api.Debugf("Creating local queue consumer/producer...")
    		api = api.WithProducer(queues_mem.NewProducer(queues_mem.NewConsumer(api.localQueueEventHandlers)))

    		//placeholder SQS event listener: nothing sends to this channel yet,
    		//the goroutine only logs whatever arrives
    		sqsEvents := make(chan events.SQSEvent)
    		var sqsListener sync.WaitGroup
    		sqsListener.Add(1)
    		go func() {
    			defer sqsListener.Done()
    			for event := range sqsEvents {
    				logger.Debugf("NOT YET PROCESSING SQS Event: %+v", event)
    			}
    		}()

    		//registered on Run itself so it fires after ListenAndServe returns
    		//(including during the panic below): stop the listener and wait for it
    		defer func() {
    			close(sqsEvents)
    			sqsListener.Wait()
    		}()
    	}

    	addr := fmt.Sprintf(":%d", api.localPort)
    	//calls api.ServeHTTP() which calls api.Handler(); always returns a non-nil error
    	if err := http.ListenAndServe(addr, api); err != nil {
    		panic(err)
    	}
    }
    
    //defaultCrashReporter is the ICrashReporter that New installs. It logs a
    //recovered panic together with the current goroutine stack via ctx.Errorf.
    type defaultCrashReporter struct{}

    //Catch recovers an in-flight panic, if any, and reports it with ctx.Errorf.
    //recover only stops a panic when called directly by a deferred function, so
    //Catch must itself be the deferred call (e.g. defer reporter.Catch(ctx)).
    //The recovered panic is not re-raised.
    func (defaultCrashReporter) Catch(ctx Context) {
    	if crashErr := recover(); crashErr != nil {
    		stack := string(debug.Stack())
    		ctx.Errorf("crashed: %v, with stack: %s", crashErr, stack)
    	}
    }