diff --git a/errors/error.go b/errors/error.go index 53c7c7e70bfa85f4c3100ef08cb64ebb4a23fba9..31b442d016f3d408f946688ef0895b5addeee3e3 100644 --- a/errors/error.go +++ b/errors/error.go @@ -20,7 +20,27 @@ type CustomError struct { //implement interface error: func (err CustomError) Error() string { - return err.Formatted(FormattingOptions{Causes: true}) + return err.Formatted(FormattingOptions{Causes: false}) +} + +func Is(e1, e2 error) bool { + if e1WithIs, ok := e1.(ErrorWithIs); ok { + return e1WithIs.Is(e2) + } + return e1.Error() == e2.Error() +} + +//Is() compares the message string of this or any cause to match the specified error message +func (err CustomError) Is(specificError error) bool { + if err.message == specificError.Error() { + return true + } + if err.cause != nil { + if causeWithIs, ok := err.cause.(ErrorWithIs); ok { + return causeWithIs.Is(specificError) + } + } + return false } //implement github.com/pkg/errors: Cause diff --git a/errors/errors.go b/errors/errors.go index f81e17b9b672f4c4561dcfb131c5cd1f0d6df9b7..8feac4f9d5e47eab499cfa2545bee3f8d4eb0aaa 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -13,6 +13,11 @@ type ErrorWithCause interface { Code() int } +type ErrorWithIs interface { + error + Is(specificError error) bool +} + func New(message string) error { err := &CustomError{ message: message, diff --git a/errors/errors_test.go b/errors/errors_test.go index d2ce8992c9f79a6adc630d9f692cacb2936ce58f..f9a8e2c6bce0d2cd2325d90fa5316cf7ab9d2f68 100644 --- a/errors/errors_test.go +++ b/errors/errors_test.go @@ -2,6 +2,7 @@ package errors_test import ( "encoding/json" + "fmt" "testing" "gitlab.com/uafrica/go-utils/errors" @@ -90,3 +91,42 @@ func f(i int) error { } return errors.Errorf("i=%d is odd", i) } + +func TestIs(t *testing.T) { + //in some condition program returns + + for n := 0; n <= 4; n++ { + if err := ReadDb(n); err != nil { + switch { + case errors.Is(err, errNotFound): + t.Logf("n=%d failed with NOT FOUND", n) + case errors.Is(err, errDisabled): + t.Logf("n=%d failed because disabled", n) + default: + t.Logf("n=%d failed for unknown cause: %v", n, err) + } + } else { + t.Logf("n=%d worked", n) + } + } +} + +var ( + errNotFound = errors.New("Not Found") + errDisabled = errors.New("Disabled") +) + +func ReadDb(x int) error { + switch x { + case 1: + return errNotFound + case 2: + return errDisabled + case 3: + return nil + case 4: + return fmt.Errorf("some silly bug") + default: + return errors.Errorf("invalid x=%d", x) + } +}