diff --git a/api/context.go b/api/context.go index 7cc3f528975041e887ab4f44365d8e6cc5f60fd2..eea6c5ca30dc49bf99117ef1869d2daab1dc9eb6 100644 --- a/api/context.go +++ b/api/context.go @@ -58,13 +58,8 @@ func (ctx *apiContext) LogAPIRequestAndResponse(res events.APIGatewayProxyRespon //allocate struct for params, populate it from the URL parameters then validate and return the struct func (ctx apiContext) GetRequestParams(paramsStructType reflect.Type) (interface{}, error) { - paramValues := map[string]interface{}{} - for n, v := range ctx.request.QueryStringParameters { - paramValues[n] = v - } paramsStructValuePtr := reflect.New(paramsStructType) - - if err := ctx.extract("params", paramsStructType, paramsStructValuePtr.Elem()); err != nil { + if err := ctx.setParamsInStruct("params", paramsStructType, paramsStructValuePtr.Elem()); err != nil { return nil, errors.Wrapf(err, "failed to put query param values into struct") } if err := ctx.applyClaim("params", paramsStructValuePtr.Interface()); err != nil { @@ -78,30 +73,37 @@ func (ctx apiContext) GetRequestParams(paramsStructType reflect.Type) (interface return paramsStructValuePtr.Elem().Interface(), nil } -func (ctx apiContext) extract(name string, t reflect.Type, v reflect.Value) error { +//extract params into a struct value +func (ctx apiContext) setParamsInStruct(name string, t reflect.Type, v reflect.Value) error { for i := 0; i < t.NumField(); i++ { - f := t.Field(i) - switch f.Type.Kind() { - case reflect.Struct: - if err := ctx.extract(name+"."+f.Name, t.Field(i).Type, v.Field(i)); err != nil { - return errors.Wrapf(err, "failed to fill sub %s.%s", name, f.Name) + tf := t.Field(i) + //enter into anonymous sub-structs + if tf.Anonymous { + if tf.Type.Kind() == reflect.Struct { + if err := ctx.setParamsInStruct(name+"."+tf.Name, t.Field(i).Type, v.Field(i)); err != nil { + return errors.Wrapf(err, "failed on parameters %s.%s", name, tf.Name) + } + continue } - continue - default: + return errors.Errorf("parameters cannot parse into anonymous %s field %s", tf.Type.Kind(), tf.Type.Name()) } - n := (strings.SplitN(f.Tag.Get("json"), ",", 2))[0] + //named field: + //use name from json tag, else lowercase of field name + n := (strings.SplitN(tf.Tag.Get("json"), ",", 2))[0] if n == "" { - n = strings.ToLower(f.Name) + n = strings.ToLower(tf.Name) } if n == "" || n == "-" { - continue + continue //skip fields without name } - //get value(s) from query string + //see if this named param was specified var paramStrValues []string if paramStrValue, isDefined := ctx.request.QueryStringParameters[n]; isDefined { + //specified once in URL if len(paramStrValue) >= 2 && paramStrValue[0] == '[' && paramStrValue[len(paramStrValue)-1] == ']' { + //specified as CSV inside [...] e.g. id=[1,2,3] csvReader := csv.NewReader(strings.NewReader(paramStrValue[1 : len(paramStrValue)-1])) var err error paramStrValues, err = csvReader.Read() @@ -109,43 +111,58 @@ func (ctx apiContext) extract(name string, t reflect.Type, v reflect.Value) erro return errors.Wrapf(err, "invalid CSV: [%s]", paramStrValue) } } else { - paramStrValues = []string{paramStrValue} //single value + //specified as single value only e.g. id=1 + paramStrValues = []string{paramStrValue} } } else { + //specified multiple times e.g. id=1&id=2&id=3 paramStrValues = ctx.request.MultiValueQueryStringParameters[n] } if len(paramStrValues) == 0 { continue //param has no value specified in URL } + valueField := v.Field(i) + if valueField.Kind() == reflect.Ptr { + valueField.Set(reflect.New(valueField.Type().Elem())) + valueField = valueField.Elem() + } + //param is defined >=1 times in URL - if f.Type.Kind() == reflect.Slice { - //iterate over all specified values - for index, paramStrValue := range paramStrValues { - newValuePtr := reflect.New(f.Type.Elem()) - if err := reflection.SetValue(newValuePtr.Elem(), paramStrValue); err != nil { - return errors.Wrapf(err, "failed to set %s[%d]=%s", n, index, paramStrValues[0]) + if tf.Type.Kind() == reflect.Slice { + //this param struct field is a slice, iterate over all specified values + for i, paramStrValue := range paramStrValues { + paramValue, err := parseParamValue(paramStrValue, tf.Type.Elem()) + if err != nil { + return errors.Wrapf(err, "invalid %s[%d]", n, i) } - v.Field(i).Set(reflect.Append(v.Field(i), newValuePtr.Elem())) + valueField.Set(reflect.Append(valueField, paramValue)) } } else { if len(paramStrValues) > 1 { - return errors.Errorf("%s does not support >1 values(%v)", n, strings.Join(paramStrValues, ",")) + return errors.Errorf("parameter %s does not support multiple values [%s]", n, strings.Join(paramStrValues, ",")) } //single value specified - valueField := v.Field(i) - if valueField.Kind() == reflect.Ptr { - valueField.Set(reflect.New(valueField.Type().Elem())) - valueField = valueField.Elem() - } - if err := reflection.SetValue(valueField, paramStrValues[0]); err != nil { - return errors.Wrapf(err, "failed to set %s=%s", n, paramStrValues[0]) + paramValue, err := parseParamValue(paramStrValues[0], valueField.Type()) + if err != nil { + return errors.Wrapf(err, "invalid %s", n) } + valueField.Set(paramValue) } } //for each param struct field return nil } +func parseParamValue(s string, t reflect.Type) (reflect.Value, error) { + newValuePtr := reflect.New(t) + if err := json.Unmarshal([]byte("\""+s+"\""), newValuePtr.Interface()); err != nil { + if err := json.Unmarshal([]byte(s), newValuePtr.Interface()); err != nil { + return newValuePtr.Elem(), errors.Wrapf(err, "invalid \"%s\"", s) + } + } + return newValuePtr.Elem(), nil +} + func (ctx apiContext) GetRequestBody(requestStructType reflect.Type) (interface{}, error) { requestStructValuePtr := reflect.New(requestStructType) err := json.Unmarshal([]byte(ctx.request.Body), requestStructValuePtr.Interface()) diff --git a/api/params_test.go b/api/params_test.go index cb5dc0d1c0366d1282f9e69037a502cfdfeec707..b9586e99c1a875103ad79a47e574fab20d197551 100644 --- a/api/params_test.go +++ b/api/params_test.go @@ -2,8 +2,10 @@ package api_test import ( "context" + "encoding/json" "reflect" "testing" + "time" "github.com/aws/aws-lambda-go/events" "gitlab.com/uafrica/go-utils/api" @@ -34,6 +36,7 @@ func TestNested(t *testing.T) { ctx, err = api.New("request-id", nil).NewContext( context.Background(), "123", + //all URL params are specified as string values events.APIGatewayProxyRequest{ QueryStringParameters: map[string]string{ "a": "1", //must be written into P3.P2.P1.A @@ -47,11 +50,11 @@ func TestNested(t *testing.T) { }, }) if err != nil { - t.Fatal(err) + t.Fatalf("ERROR: %+v", err) } if p3d, err := ctx.GetRequestParams(reflect.TypeOf(P3{})); err != nil { - t.Fatal(err) + t.Fatalf("ERROR: %+v", err) } else { p3 := p3d.(P3) t.Logf("p3: %+v", p3) @@ -67,6 +70,77 @@ func TestNested(t *testing.T) { } } +type ParamTypes struct { + GetParams + Nr int64 `json:"nr"` + Name string `json:"name"` + NrOpt *int64 `json:"nr_opt"` + NameOpt *string `json:"name_opt"` + Time1 time.Time `json:"time1"` + Time2 *time.Time `json:"time2"` + Dur1 time.Duration `json:"dur1"` + Dur2 *time.Duration `json:"dur2"` + + //lists of values + NrList []int64 `json:"nrs"` + NameList []string `json:"names"` + NrOptList []*int64 `json:"nrs_opt"` + NameOptList []*string `json:"names_opt"` + Time1List []time.Time `json:"time1s"` + Time2List []*time.Time `json:"time2s"` + Dur1List []time.Duration `json:"dur1s"` + Dur2List []*time.Duration `json:"dur2s"` +} + +func TestTypes(t *testing.T) { + logger.SetGlobalLevel(logger.LevelDebug) + logger.SetGlobalFormat(logger.NewConsole()) + var ctx api.Context + var err error + ctx, err = api.New("request-id", nil).NewContext( + context.Background(), + "123", + //all URL params are specified as string values + events.APIGatewayProxyRequest{ + QueryStringParameters: map[string]string{ + "nr": "1", + "name": "name2", + "nr_opt": "3", + "name_opt": "name4", + "limit": "5", + "time1": "2021-11-23T00:00:00+00:00", + "time2": "2021-11-23T00:00:00+00:00", + "dur1": "4", //nanoseconds + "dur2": "4", //nanoseconds + "nrs": "[1,2,3]", + "nrs_opt": "[4,5,6]", + "names": "[A,B,C]", + "names_opt": "[D,E,F]", + "time1s": "[2021-11-23T00:00:00+00:00]", + "dur1s": "[4,5,6]", //nanoseconds + }, + MultiValueQueryStringParameters: map[string][]string{ + "dur2s": {"11", "12", "13"}, + "time2s": {"2021-11-23T00:00:00+00:00", "2021-11-23T00:00:00+00:00", "2021-11-23T00:00:00+00:00"}, + }, + }) + if err != nil { + t.Fatalf("ERROR: %+v", err) + } + + if pd, err := ctx.GetRequestParams(reflect.TypeOf(ParamTypes{})); err != nil { + t.Fatalf("ERROR: %+v", err) + } else { + p := pd.(ParamTypes) + t.Logf("p: %+v", p) + if p.Nr != 1 || p.Name != "name2" || p.NrOpt == nil || *p.NrOpt != 3 || p.NameOpt == nil || *p.NameOpt != "name4" || p.Limit != 5 { + t.Errorf("Wrong values: %+v", p) + } + jsonParams, _ := json.Marshal(p) + t.Logf("params: %s", string(jsonParams)) + } +} + type PageParams struct { Limit int64 `json:"limit"` Offset int64 `json:"offset"` @@ -94,6 +168,7 @@ func TestGet(t *testing.T) { ctx, err = api.New("request-id", nil).NewContext( context.Background(), "123", + //all URL params are specified as string values events.APIGatewayProxyRequest{ QueryStringParameters: map[string]string{ "id": "1", @@ -109,11 +184,11 @@ func TestGet(t *testing.T) { }, }) if err != nil { - t.Fatal(err) + t.Fatalf("ERROR: %+v", err) } if p3d, err := ctx.GetRequestParams(reflect.TypeOf(MyGetParams{})); err != nil { - t.Fatal(err) + t.Fatalf("ERROR: %+v", err) } else { get := p3d.(MyGetParams) t.Logf("get: %+v", get)