Skip to content
Snippets Groups Projects
Select Git revision
  • 90ec162f4b59c68672a981f51916632a94e95b1c
  • main default protected
  • v1.298.0
  • v1.297.0
  • v1.296.0
  • v1.295.0
  • v1.294.0
  • v1.293.0
  • v1.292.0
  • v1.291.0
  • v1.290.0
  • v1.289.0
  • v1.288.0
  • v1.287.0
  • v1.286.0
  • v1.285.0
  • v1.284.0
  • v1.283.0
  • v1.282.0
  • v1.281.0
  • v1.280.0
  • v1.279.0
22 results

context.go

Blame
  • context.go 7.53 KiB
    package api
    
    import (
    	"context"
    	"encoding/json"
    	"fmt"
    	"reflect"
    	"strconv"
    	"strings"
    	"time"
    
    	"github.com/aws/aws-lambda-go/events"
    	"github.com/uptrace/bun"
    	"gitlab.com/uafrica/go-utils/errors"
    	"gitlab.com/uafrica/go-utils/logger"
    	"gitlab.com/uafrica/go-utils/queues"
    	"gitlab.com/uafrica/go-utils/service"
    )
    
    // IContext describes the context available to API request-handling code in
    // this package: a standard context.Context extended with the logger.ILogger
    // and queues.IProducer method sets, request timing helpers, and accessors for
    // values recorded by checks.
    type IContext interface {
    	// Standard library context (deadline, cancellation, values).
    	context.Context
    	// Logging methods.
    	logger.ILogger
    	// Queue producer methods.
    	queues.IProducer

    	// CheckValue returns one named value recorded by the named check.
    	CheckValue(checkName, valueName string) interface{}
    	// CheckValues returns all values recorded by the named check.
    	CheckValues(checkName string) interface{}

    	// MillisecondsSinceStart reports the elapsed time since StartTime in ms.
    	MillisecondsSinceStart() int64
    	// StartTime returns the start time recorded for this context.
    	StartTime() time.Time
    }
    
    // Context is the concrete request context used by the handlers in this file.
    // It embeds service.Context and a queues.IProducer, so their methods are
    // promoted onto Context.
    //
    // Fields:
    //   - Request: the API Gateway proxy request being handled; its path, method,
    //     query parameters and body are read by the methods in this file.
    //   - RequestID: logged as "api_gateway_request_id" by LogAPIRequestAndResponse.
    //   - ValuesFromChecks: values recorded by checks, keyed first by check name and
    //     then by value name (see CheckValue / CheckValues).
    //   - DB: database handle; not read by any method in this file.
    type Context struct {
    	service.Context
    	queues.IProducer
    	Request          events.APIGatewayProxyRequest
    	RequestID        string
    	ValuesFromChecks map[string]map[string]interface{} //also in context.Value(), but cannot retrieve iteratively from there for logging...
    	DB               *bun.DB
    }
    
    // CheckValues returns all values recorded under checkName (a
    // map[string]interface{} boxed in interface{}), or an untyped nil when no
    // entry exists for that check name.
    func (ctx Context) CheckValues(checkName string) interface{} {
    	values, found := ctx.ValuesFromChecks[checkName]
    	if !found {
    		return nil
    	}
    	return values
    }
    
    // CheckValue returns the value valueName recorded by the check checkName, or
    // nil when either the check or the value is absent. Indexing a nil or missing
    // inner map yields the zero interface{} (nil), so no explicit lookups are
    // needed.
    func (ctx Context) CheckValue(checkName, valueName string) interface{} {
    	return ctx.ValuesFromChecks[checkName][valueName]
    }
    
    // func (ctx Context) Audit(org, new interface{}, eventType types.AuditEventType) {
    // 	//call old function for now - should become part of context ONLY
    // 	audit.SaveAuditEvent(org, new, ctx.Claims, eventType, &ctx.RequestID)
    // }
    
    // LogAPIRequestAndResponse writes one info-level log entry summarising an API
    // request and the response produced for it.
    //
    // Every entry carries the request path, HTTP method, response status code and
    // API Gateway request ID. GET requests also include the query parameters.
    // When err is non-nil or the status code is outside 200-299, the entry
    // additionally includes the error, request body, query parameters, response
    // body, and every check value flattened under the key "<check>_<value>".
    //
    // TODO: turn this into a ctx method that is deferred, so it does not have to
    // be called explicitly. It should also capture metrics for the handler and
    // automatically write the audit record (while still allowing auditing to be
    // suppressed in some cases).
    func (ctx Context) LogAPIRequestAndResponse(res events.APIGatewayProxyResponse, err error) {
    	req := ctx.Request

    	fields := make(map[string]interface{}, 4)
    	fields["path"] = req.Path
    	fields["method"] = req.HTTPMethod
    	fields["status_code"] = res.StatusCode
    	fields["api_gateway_request_id"] = ctx.RequestID

    	isGet := req.HTTPMethod == "GET"
    	failed := err != nil || res.StatusCode < 200 || res.StatusCode > 299

    	if isGet || failed {
    		fields["req-query"] = req.QueryStringParameters
    	}
    	if failed {
    		fields["error"] = err
    		fields["req-body"] = req.Body
    		fields["res-body"] = res.Body
    		// Check values are written last so they take precedence on key clashes.
    		for checkName, values := range ctx.ValuesFromChecks {
    			for valueName, value := range values {
    				fields[checkName+"_"+valueName] = value
    			}
    		}
    	}

    	ctx.WithFields(fields).Infof("Request & Response: err=%+v", err)
    }
    
    // GetRequestParams allocates a new value of paramsStructType, populates its
    // fields from the request's URL query parameters, validates it and returns it.
    //
    // paramsStructType must be a struct type; otherwise an error is returned.
    // Each exported field is matched to a query parameter by the name in its
    // `json` tag (the text before the first comma), or by the lower-cased field
    // name when the tag has no name. Fields whose tag name is "-" and unexported
    // fields are skipped. Slice fields accept the parameter any number of times
    // (one element per occurrence); other fields accept it at most once. Fields
    // whose parameter is absent keep their zero value.
    //
    // If a pointer to the populated struct implements IValidator, Validate is
    // called and a non-nil result is returned wrapped as "invalid params".
    //
    // The returned interface{} holds the struct by value, not a pointer.
    func (ctx Context) GetRequestParams(paramsStructType reflect.Type) (interface{}, error) {
    	if paramsStructType == nil || paramsStructType.Kind() != reflect.Struct {
    		return nil, errors.Errorf("params type %v is not a struct", paramsStructType)
    	}

    	paramsStructValuePtr := reflect.New(paramsStructType)
    	paramsStructValue := paramsStructValuePtr.Elem()
    	for i := 0; i < paramsStructType.NumField(); i++ {
    		f := paramsStructType.Field(i)
    		if f.PkgPath != "" {
    			// Unexported: reflect cannot set it (encoding/json ignores it too).
    			continue
    		}
    		n := strings.SplitN(f.Tag.Get("json"), ",", 2)[0]
    		if n == "" {
    			n = strings.ToLower(f.Name)
    		}
    		if n == "-" {
    			continue
    		}

    		// Prefer the multi-value map: when a parameter is repeated, API Gateway
    		// lists every occurrence there but keeps only the last one in
    		// QueryStringParameters. Fall back to the single-value map for events
    		// that populate only that one.
    		paramStrValues := ctx.Request.MultiValueQueryStringParameters[n]
    		if len(paramStrValues) == 0 {
    			if paramStrValue, isDefined := ctx.Request.QueryStringParameters[n]; isDefined {
    				paramStrValues = []string{paramStrValue}
    			}
    		}
    		if len(paramStrValues) == 0 {
    			continue // param has no value specified in URL
    		}

    		fieldValue := paramsStructValue.Field(i)
    		if f.Type.Kind() == reflect.Slice {
    			// One slice element per occurrence of the parameter.
    			for index, paramStrValue := range paramStrValues {
    				elemPtr := reflect.New(f.Type.Elem())
    				if err := setParamFromStr(fmt.Sprintf("%s[%d]", n, index), elemPtr.Elem(), paramStrValue); err != nil {
    					return nil, errors.Wrapf(err, "failed to set %s[%d]=%s", n, index, paramStrValue)
    				}
    				fieldValue.Set(reflect.Append(fieldValue, elemPtr.Elem()))
    			}
    			continue
    		}

    		if len(paramStrValues) > 1 {
    			return nil, errors.Errorf("%s does not support >1 values(%v)", n, strings.Join(paramStrValues, ","))
    		}
    		if err := setParamFromStr(n, fieldValue, paramStrValues[0]); err != nil {
    			return nil, errors.Wrapf(err, "failed to set %s=%s", n, paramStrValues[0])
    		}
    	}

    	if validator, ok := paramsStructValuePtr.Interface().(IValidator); ok {
    		if err := validator.Validate(); err != nil {
    			return nil, errors.Wrapf(err, "invalid params")
    		}
    	}

    	return paramsStructValue.Interface(), nil
    }
    
    // setParamFromStr parses paramStrValue according to the kind of fieldValue and
    // stores the result in fieldValue, which must be settable.
    //
    // Supported kinds are the signed and unsigned integer kinds, bool, string,
    // float32 and float64. Named types with one of those underlying kinds (e.g.
    // `type Status int`, time.Duration) are supported as well, because values are
    // stored with the kind-based setters. Integers are parsed in base 10 and must
    // fit the field's bit size; out-of-range input is rejected rather than being
    // silently truncated. Booleans are lenient: "true", "yes" or "1"
    // (case-insensitive) set the field to true, and any other text leaves it
    // unchanged. fieldName is used only for debug logging.
    func setParamFromStr(fieldName string, fieldValue reflect.Value, paramStrValue string) error {
    	logger.Debugf("Set(%s,%v,%v,%v)", fieldName, fieldValue.Type(), fieldValue.Kind(), paramStrValue)

    	// numberError converts a strconv integer parse failure into the error
    	// reported to callers, distinguishing overflow from non-numeric text.
    	numberError := func(err error) error {
    		if numErr, ok := err.(*strconv.NumError); ok && numErr.Err == strconv.ErrRange {
    			return errors.Errorf("%s is out of range for %v", paramStrValue, fieldValue.Type())
    		}
    		return errors.Errorf("%s is not a number", paramStrValue)
    	}

    	switch fieldValue.Kind() {
    	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    		// Parse with the field's own size so e.g. "300" fails for int8.
    		i64, err := strconv.ParseInt(paramStrValue, 10, fieldValue.Type().Bits())
    		if err != nil {
    			return numberError(err)
    		}
    		fieldValue.SetInt(i64)

    	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
    		u64, err := strconv.ParseUint(paramStrValue, 10, fieldValue.Type().Bits())
    		if err != nil {
    			return numberError(err)
    		}
    		fieldValue.SetUint(u64)

    	case reflect.Bool:
    		bs := strings.ToLower(paramStrValue)
    		if bs == "true" || bs == "yes" || bs == "1" {
    			fieldValue.SetBool(true)
    		}

    	case reflect.String:
    		fieldValue.SetString(paramStrValue)

    	case reflect.Float32, reflect.Float64:
    		f64, err := strconv.ParseFloat(paramStrValue, fieldValue.Type().Bits())
    		if err != nil {
    			return errors.Wrapf(err, "invalid float")
    		}
    		fieldValue.SetFloat(f64)

    	default:
    		return errors.Errorf("unsupported type %v", fieldValue.Kind())
    	}
    	return nil
    }
    
    // GetRequestBody decodes the request body as JSON into a newly allocated value
    // of requestStructType, validates it and returns it.
    //
    // A nil requestStructType, or a body that fails json.Unmarshal (including an
    // empty body), yields an error. If a pointer to the decoded value implements
    // IValidator, Validate is called and a non-nil result is returned wrapped as
    // "invalid request body". The returned interface{} holds the decoded value
    // itself, not a pointer to it.
    func (ctx Context) GetRequestBody(requestStructType reflect.Type) (interface{}, error) {
    	if requestStructType == nil {
    		return nil, errors.Errorf("request body type is nil")
    	}

    	requestStructValuePtr := reflect.New(requestStructType)
    	if err := json.Unmarshal([]byte(ctx.Request.Body), requestStructValuePtr.Interface()); err != nil {
    		return nil, errors.Wrapf(err, "failed to decode JSON request body")
    	}

    	if validator, ok := requestStructValuePtr.Interface().(IValidator); ok {
    		if err := validator.Validate(); err != nil {
    			return nil, errors.Wrapf(err, "invalid request body")
    		}
    	}

    	return requestStructValuePtr.Elem().Interface(), nil
    }
    
    // IValidator is implemented by types that can check their own contents.
    // GetRequestParams and GetRequestBody call Validate on a pointer to the
    // populated value and reject the request when it returns a non-nil error.
    type IValidator interface{ Validate() error }