diff --git a/handler_utils/api.go b/handler_utils/api.go index 90e25aa292cba0a5580e814e04bbda1f867ce074..ce89f0c442821ddab9231dbe9edfb2fec525bc20 100644 --- a/handler_utils/api.go +++ b/handler_utils/api.go @@ -4,6 +4,9 @@ import ( "github.com/aws/aws-lambda-go/events" "gitlab.com/uafrica/go-utils/errors" "gitlab.com/uafrica/go-utils/logs" + "reflect" + "strconv" + "strings" ) // ValidateAPIEndpoints checks that all API endpoints are correctly defined using one of the supported handler types @@ -49,3 +52,127 @@ func ValidateAPIEndpoints(endpoints map[string]map[string]interface{}) (map[stri logs.Info("Checked %d legacy and %d new handlers\n", countLegacy, countHandler) return endpoints, nil } + +func ValidateRequestParams(request *events.APIGatewayProxyRequest, paramsStructType reflect.Type) (reflect.Value, error) { + paramValues := map[string]interface{}{} + for n, v := range request.QueryStringParameters { + paramValues[n] = v + } + paramsStructValuePtr := reflect.New(paramsStructType) + for i := 0; i < paramsStructType.NumField(); i++ { + f := paramsStructType.Field(i) + n := (strings.SplitN(f.Tag.Get("json"), ",", 2))[0] + if n == "" { + n = strings.ToLower(f.Name) + } + if n == "" || n == "-" { + continue + } + + //get value(s) from query string + var paramStrValues []string + if paramStrValue, isDefined := request.QueryStringParameters[n]; isDefined { + paramStrValues = []string{paramStrValue} //single value + } else { + paramStrValues = request.MultiValueQueryStringParameters[n] + } + if len(paramStrValues) == 0 { + continue //param has no value specified in URL + } + + //param is defined >=1 times in URL + if f.Type.Kind() == reflect.Slice { + //iterate over all specified values + for _, paramStrValue := range paramStrValues { + newValuePtr := reflect.New(f.Type.Elem()) + if err := setParamFromStr( + newValuePtr.Elem(), + paramStrValue); err != nil { + return reflect.Value{}, errors.Wrapf(err, "failed to set %s[%d]=%s", n, i, paramStrValues[0]) + } + paramsStructValuePtr.Elem().Field(i).Set(reflect.Append(paramsStructValuePtr.Elem().Field(i), newValuePtr.Elem())) + } + } else { + if len(paramStrValues) > 1 { + return reflect.Value{}, errors.Errorf("%s does not support >1 values(%v)", n, strings.Join(paramStrValues, ",")) + } + //single value specified + if err := setParamFromStr(paramsStructValuePtr.Elem().Field(i), paramStrValues[0]); err != nil { + return reflect.Value{}, errors.Wrapf(err, "failed to set %s=%s", n, paramStrValues[0]) + } + } + } //for each param struct field + + return paramsStructValuePtr, nil +} + + + +func setParamFromStr(fieldValue reflect.Value, paramStrValue string) error { + switch fieldValue.Type().Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + //parse to int for this struct field + if i64, err := strconv.ParseInt(paramStrValue, 10, 64); err != nil { + return errors.Errorf("%s is not a number", paramStrValue) + } else { + switch fieldValue.Type().Kind() { + case reflect.Int: + fieldValue.Set(reflect.ValueOf(int(i64))) + case reflect.Int8: + fieldValue.Set(reflect.ValueOf(int8(i64))) + case reflect.Int16: + fieldValue.Set(reflect.ValueOf(int16(i64))) + case reflect.Int32: + fieldValue.Set(reflect.ValueOf(int32(i64))) + case reflect.Int64: + fieldValue.Set(reflect.ValueOf(i64)) + } + } + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + //parse to int for this struct field + if u64, err := strconv.ParseUint(paramStrValue, 10, 64); err != nil { + return errors.Errorf("%s is not a number", paramStrValue) + } else { + switch fieldValue.Type().Kind() { + case reflect.Uint: + fieldValue.Set(reflect.ValueOf(uint(u64))) + case reflect.Uint8: + fieldValue.Set(reflect.ValueOf(uint8(u64))) + case reflect.Uint16: + fieldValue.Set(reflect.ValueOf(uint16(u64))) + case reflect.Uint32: + fieldValue.Set(reflect.ValueOf(uint32(u64))) + case reflect.Uint64: + fieldValue.Set(reflect.ValueOf(u64)) + } + } + + case reflect.Bool: + bs := strings.ToLower(paramStrValue) + if bs == "true" || bs == "yes" || bs == "1" { + fieldValue.Set(reflect.ValueOf(true)) + } + + case reflect.String: + fieldValue.Set(reflect.ValueOf(paramStrValue)) + + case reflect.Float32: + if f32, err := strconv.ParseFloat(paramStrValue, 32); err != nil { + return errors.Wrapf(err, "invalid float") + } else { + fieldValue.Set(reflect.ValueOf(float32(f32))) + } + + case reflect.Float64: + if f64, err := strconv.ParseFloat(paramStrValue, 64); err != nil { + return errors.Wrapf(err, "invalid float") + } else { + fieldValue.Set(reflect.ValueOf(f64)) + } + + default: + return errors.Errorf("unsupported type %v", fieldValue.Type().Kind()) + } //switch param struct field + return nil +}