From a0455aeef58d73441e2951675f430ffe53dbde6e Mon Sep 17 00:00:00 2001
From: Jan Semmelink <jan@uafrica.com>
Date: Thu, 14 Oct 2021 15:19:17 +0200
Subject: [PATCH] Support parsing of URL params into nested params structs and
 arrays when specified same param multiple times. Not (yet) parsing id=[1,2,3]
 into array

---
 api/api.go         |   2 +-
 api/context.go     |  61 ++++++++++++++--------
 api/handler.go     |  20 ++++++--
 api/lambda.go      |  33 +++++++-----
 api/params_test.go | 124 +++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 199 insertions(+), 41 deletions(-)
 create mode 100644 api/params_test.go

diff --git a/api/api.go b/api/api.go
index cd105cb..7b7bc9d 100644
--- a/api/api.go
+++ b/api/api.go
@@ -148,7 +148,7 @@ func (api Api) WithEvents(eventHandlers map[string]interface{}) Api {
 //run and panic on error
 func (api Api) Run() {
 	//decide local or SQS
-	if os.Getenv("LOG_LEVEL") == "debug" && api.localQueueEventHandlers != nil {
+	if (api.localPort > 0 || os.Getenv("LOG_LEVEL") == "debug") && api.localQueueEventHandlers != nil {
 		//use in-memory channels for async events
 		api.Debugf("Using in-memory channels for async events ...")
 		memConsumer := queues_mem.NewConsumer(api.localQueueEventHandlers)
diff --git a/api/context.go b/api/context.go
index f031af7..c559326 100644
--- a/api/context.go
+++ b/api/context.go
@@ -14,6 +14,9 @@ import (
 type Context interface {
 	service.Context
 	Request() events.APIGatewayProxyRequest
+	GetRequestParams(paramsStructType reflect.Type) (interface{}, error)
+	GetRequestBody(requestStructType reflect.Type) (interface{}, error)
+	LogAPIRequestAndResponse(res events.APIGatewayProxyResponse, err error)
 }
 
 var contextInterfaceType = reflect.TypeOf((*Context)(nil)).Elem()
@@ -59,8 +62,33 @@ func (ctx apiContext) GetRequestParams(paramsStructType reflect.Type) (interface
 		paramValues[n] = v
 	}
 	paramsStructValuePtr := reflect.New(paramsStructType)
-	for i := 0; i < paramsStructType.NumField(); i++ {
-		f := paramsStructType.Field(i)
+
+	if err := ctx.extract("params", paramsStructType, paramsStructValuePtr.Elem()); err != nil {
+		return nil, errors.Wrapf(err, "failed to put query param values into struct")
+	}
+	if err := ctx.applyClaim("params", paramsStructValuePtr.Interface()); err != nil {
+		return nil, errors.Wrapf(err, "failed to fill claims on params")
+	}
+	if validator, ok := paramsStructValuePtr.Interface().(IValidator); ok {
+		if err := validator.Validate(); err != nil {
+			return nil, errors.Wrapf(err, "invalid params")
+		}
+	}
+	return paramsStructValuePtr.Elem().Interface(), nil
+}
+
+func (ctx apiContext) extract(name string, t reflect.Type, v reflect.Value) error {
+	for i := 0; i < t.NumField(); i++ {
+		f := t.Field(i)
+		switch f.Type.Kind() {
+		case reflect.Struct:
+			if err := ctx.extract(name+"."+f.Name, t.Field(i).Type, v.Field(i)); err != nil {
+				return errors.Wrapf(err, "failed to fill sub %s.%s", name, f.Name)
+			}
+			continue
+		default:
+		}
+
 		n := (strings.SplitN(f.Tag.Get("json"), ",", 2))[0]
 		if n == "" {
 			n = strings.ToLower(f.Name)
@@ -86,32 +114,21 @@ func (ctx apiContext) GetRequestParams(paramsStructType reflect.Type) (interface
 			for index, paramStrValue := range paramStrValues {
 				newValuePtr := reflect.New(f.Type.Elem())
 				if err := reflection.SetValue(newValuePtr.Elem(), paramStrValue); err != nil {
-					return nil, errors.Wrapf(err, "failed to set %s[%d]=%s", n, index, paramStrValues[0])
+					return errors.Wrapf(err, "failed to set %s[%d]=%s", n, index, paramStrValues[0])
 				}
-				paramsStructValuePtr.Elem().Field(i).Set(reflect.Append(paramsStructValuePtr.Elem().Field(i), newValuePtr.Elem()))
+				v.Field(i).Set(reflect.Append(v.Field(i), newValuePtr.Elem()))
 			}
 		} else {
 			if len(paramStrValues) > 1 {
-				return nil, errors.Errorf("%s does not support >1 values(%v)", n, strings.Join(paramStrValues, ","))
+				return errors.Errorf("%s does not support >1 values(%v)", n, strings.Join(paramStrValues, ","))
 			}
 			//single value specified
-			if err := reflection.SetValue(paramsStructValuePtr.Elem().Field(i), paramStrValues[0]); err != nil {
-				return nil, errors.Wrapf(err, "failed to set %s=%s", n, paramStrValues[0])
+			if err := reflection.SetValue(v.Field(i), paramStrValues[0]); err != nil {
+				return errors.Wrapf(err, "failed to set %s=%s", n, paramStrValues[0])
 			}
 		}
 	} //for each param struct field
-
-	if err := ctx.applyClaim("params", paramsStructValuePtr.Interface()); err != nil {
-		return nil, errors.Wrapf(err, "failed to fill claims on params")
-	}
-
-	if validator, ok := paramsStructValuePtr.Interface().(IValidator); ok {
-		if err := validator.Validate(); err != nil {
-			return nil, errors.Wrapf(err, "invalid params")
-		}
-	}
-
-	return paramsStructValuePtr.Elem().Interface(), nil
+	return nil
 }
 
 func (ctx apiContext) GetRequestBody(requestStructType reflect.Type) (interface{}, error) {
@@ -160,9 +177,9 @@ func (ctx *apiContext) setClaim(name string, structType reflect.Type, structValu
 				return errors.Errorf("failed to set %s.%s=(%T)%v", structType.Name(), fieldName, claimValue, claimValue)
 			}
 			ctx.Debugf("defined claim %s.%s=(%T)%v ...", name, fieldName, claimValue, claimValue)
-		} /* else {
-			ctx.Debugf("claim(%s) does not apply to %s", fieldName, structType.Name())
-		}*/
+			// } else {
+			// 	ctx.Debugf("claim(%s) does not apply to %s", fieldName, structType.Name())
+		}
 	}
 
 	//recurse into sub-structs and sub struct ptrs (not yet slices)
diff --git a/api/handler.go b/api/handler.go
index 3aa1ceb..a258079 100644
--- a/api/handler.go
+++ b/api/handler.go
@@ -35,7 +35,7 @@ func NewHandler(fnc interface{}) (handler, error) {
 	//arg[1] must be a struct for params. It may be an empty struct, but
 	//all public fields require a json tag which we will use to math the URL param name
 	if err := validateStructType(fncType.In(1)); err != nil {
-		return h, errors.Errorf("second arg %v is not valid params struct type", fncType.In(1))
+		return h, errors.Wrapf(err, "second arg %v is not valid params struct type", fncType.In(1))
 	}
 	h.RequestParamsType = fncType.In(1)
 
@@ -87,9 +87,19 @@ func validateStructType(t reflect.Type) error {
 	if t.Kind() != reflect.Struct {
 		return errors.Errorf("%v is %v, not a struct", t, t.Kind())
 	}
-	// for i := 0; i < t.NumField(); i++ {
-	// 	f := t.Field(i)
-	// 	if f.... check tags recursively... for now, not too strict ... add checks if we see issues that break the API, to help dev to fix before we deploy, or to prevent bad habits...
-	// }
+	for i := 0; i < t.NumField(); i++ {
+		f := t.Field(i)
+		if f.Name[0] >= 'a' && f.Name[0] <= 'z' {
+			//lowercase fields should not have json tag
+			if f.Tag.Get("json") != "" {
+				return errors.Errorf("%s.%s must be uppercase because it has a json tag \"%s\"",
+					t.Name(),
+					f.Name,
+					f.Tag.Get("json"))
+			}
+		}
+
+		// 	if f.... check tags recursively... for now, not too strict ... add checks if we see issues that break the API, to help dev to fix before we deploy, or to prevent bad habits...
+	}
 	return nil
 }
diff --git a/api/lambda.go b/api/lambda.go
index f4fe027..b2aeaed 100644
--- a/api/lambda.go
+++ b/api/lambda.go
@@ -16,6 +16,18 @@ import (
 	"gitlab.com/uafrica/go-utils/logger"
 )
 
+func (api Api) NewContext(baseCtx context.Context, requestID string, request events.APIGatewayProxyRequest) (Context, error) {
+	serviceContext, err := api.Service.NewContext(baseCtx, requestID, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	return &apiContext{
+		Context: serviceContext,
+		request: request,
+	}, nil
+}
+
 //this is native handler for lambda passed into lambda.Start()
 //to run locally, this is called from app.ServeHTTP()
 func (api Api) Handler(baseCtx context.Context, apiGatewayProxyReq events.APIGatewayProxyRequest) (res events.APIGatewayProxyResponse, err error) {
@@ -41,16 +53,11 @@ func (api Api) Handler(baseCtx context.Context, apiGatewayProxyReq events.APIGat
 	}
 
 	//service context invoke the starters and could fail, e.g. if cannot connect to db
-	serviceContext, err := api.Service.NewContext(baseCtx, requestID, nil)
+	ctx, err := api.NewContext(baseCtx, requestID, apiGatewayProxyReq)
 	if err != nil {
 		return res, err
 	}
 
-	ctx := &apiContext{
-		Context: serviceContext,
-		request: apiGatewayProxyReq,
-	}
-
 	//report handler crashes
 	if api.crashReporter != nil {
 		defer api.crashReporter.Catch(ctx)
@@ -94,7 +101,7 @@ func (api Api) Handler(baseCtx context.Context, apiGatewayProxyReq events.APIGat
 		}
 		if err := api.Service.WriteValues(ctx.StartTime(), time.Now(), ctx.RequestID(), map[string]interface{}{
 			"request_id": ctx.RequestID(),
-			"request":    ctx.request,
+			"request":    ctx.Request(),
 			"response":   res},
 		); err != nil {
 			ctx.Errorf("failed to audit: %+v", err)
@@ -125,15 +132,15 @@ func (api Api) Handler(baseCtx context.Context, apiGatewayProxyReq events.APIGat
 
 	//LEGACY: delete this as soon as all handlers accepts context
 	//this does not support concurrent execution!
-	CurrentRequestID = &ctx.request.RequestContext.RequestID
+	CurrentRequestID = &apiGatewayProxyReq.RequestContext.RequestID
 
 	ctx.Debugf("HTTP %s %s ...\n", apiGatewayProxyReq.HTTPMethod, apiGatewayProxyReq.Resource)
 	ctx.WithFields(map[string]interface{}{
-		"http_method":                ctx.request.HTTPMethod,
-		"path":                       ctx.request.Path,
-		"api_gateway_request_id":     ctx.request.RequestContext.RequestID,
-		"user_cognito_auth_provider": ctx.request.RequestContext.Identity.CognitoAuthenticationProvider,
-		"user_arn":                   ctx.request.RequestContext.Identity.UserArn,
+		"http_method":                ctx.Request().HTTPMethod,
+		"path":                       ctx.Request().Path,
+		"api_gateway_request_id":     ctx.Request().RequestContext.RequestID,
+		"user_cognito_auth_provider": ctx.Request().RequestContext.Identity.CognitoAuthenticationProvider,
+		"user_arn":                   ctx.Request().RequestContext.Identity.UserArn,
 	}).Infof("Start API Handler")
 
 	//TODO:
diff --git a/api/params_test.go b/api/params_test.go
new file mode 100644
index 0000000..ae0a415
--- /dev/null
+++ b/api/params_test.go
@@ -0,0 +1,124 @@
+package api_test
+
+import (
+	"context"
+	"reflect"
+	"testing"
+
+	"github.com/aws/aws-lambda-go/events"
+	"gitlab.com/uafrica/go-utils/api"
+	"gitlab.com/uafrica/go-utils/logger"
+)
+
+type P1 struct {
+	A int `json:"a"`
+}
+
+type P2 struct {
+	P1       //nested struct must be filled
+	B  int   `json:"b"`
+	F  []int `json:"f"`
+}
+
+type P3 struct {
+	P2       //nessted struct must be filled
+	C  int   `json:"c"`
+	E  []int `json:"e"`
+}
+
+func TestNested(t *testing.T) {
+	logger.SetGlobalLevel(logger.LevelDebug)
+	logger.SetGlobalFormat(logger.NewConsole())
+	var ctx api.Context
+	var err error
+	ctx, err = api.New("request-id", nil).NewContext(
+		context.Background(),
+		"123",
+		events.APIGatewayProxyRequest{
+			QueryStringParameters: map[string]string{
+				"a": "1", //must be written into P3.P2.P1.A
+				"b": "2", //must be written into P3.P2.B
+				"c": "3", //must be written into P3.C
+				"d": "4", //ignored because no field tagged "d"
+			},
+			MultiValueQueryStringParameters: map[string][]string{
+				"e": {"5", "6", "7"}, //filled into P3.E as []string {"5", "6", "7"}
+				"f": {"6", "7", "8"}, //filled into P2 as []string {"6", "7", "8"}
+			},
+		})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if p3d, err := ctx.GetRequestParams(reflect.TypeOf(P3{})); err != nil {
+		t.Fatal(err)
+	} else {
+		p3 := p3d.(P3)
+		t.Logf("p3: %+v", p3)
+		if p3.C != 3 || p3.B != 2 || p3.A != 1 {
+			t.Fatalf("wrong values")
+		}
+		if len(p3.E) != 3 || p3.E[0] != 5 || p3.E[1] != 6 || p3.E[2] != 7 {
+			t.Fatalf("wrong values")
+		}
+		if len(p3.F) != 3 || p3.F[0] != 6 || p3.F[1] != 7 || p3.F[2] != 8 {
+			t.Fatalf("wrong values")
+		}
+	}
+}
+
+type PageParams struct {
+	Limit  int64 `json:"limit"`
+	Offset int64 `json:"offset"`
+}
+
+type GetParams struct {
+	PageParams
+	ID int64 `json:"id"`
+}
+
+type MyGetParams struct {
+	GetParams
+	Search string   `json:"search"`
+	Find   string   `json:"find"`
+	Find1  []string `json:"find1"`
+	Find2  []string `json:"find2"`
+}
+
+func TestGet(t *testing.T) {
+	logger.SetGlobalLevel(logger.LevelDebug)
+	logger.SetGlobalFormat(logger.NewConsole())
+	var ctx api.Context
+	var err error
+	ctx, err = api.New("request-id", nil).NewContext(
+		context.Background(),
+		"123",
+		events.APIGatewayProxyRequest{
+			QueryStringParameters: map[string]string{
+				"id":     "1",
+				"limit":  "2",
+				"offset": "3",
+				"search": "4",     //single value parts into string
+				"find1":  "sarel", //single value parsed into array
+			},
+			MultiValueQueryStringParameters: map[string][]string{
+				"find2": {"hans", "gert"}, //multi-values parsed into array
+				"find":  {"koos"},         //field of type string can be parsed from one multi-value
+			},
+		})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if p3d, err := ctx.GetRequestParams(reflect.TypeOf(MyGetParams{})); err != nil {
+		t.Fatal(err)
+	} else {
+		get := p3d.(MyGetParams)
+		t.Logf("get: %+v", get)
+		if get.ID != 1 || get.Offset != 3 || get.Limit != 2 || get.Search != "4" || get.Find != "koos" ||
+			len(get.Find1) != 1 || get.Find1[0] != "sarel" ||
+			len(get.Find2) != 2 || get.Find2[0] != "hans" || get.Find2[1] != "gert" {
+			t.Fatalf("wrong values")
+		}
+	}
+}
-- 
GitLab