From 34f2caf5168fabeb85b11df5176900f586bd5a08 Mon Sep 17 00:00:00 2001
From: Jan Semmelink <jan@uafrica.com>
Date: Fri, 17 Sep 2021 14:12:48 +0200
Subject: [PATCH] Added errors.Is()

---
 errors/error.go       | 22 +++++++++++++++++++++-
 errors/errors.go      |  5 +++++
 errors/errors_test.go | 40 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 66 insertions(+), 1 deletion(-)

diff --git a/errors/error.go b/errors/error.go
index 53c7c7e..31b442d 100644
--- a/errors/error.go
+++ b/errors/error.go
@@ -20,7 +20,27 @@ type CustomError struct {
 
 //implement interface error:
 func (err CustomError) Error() string {
-	return err.Formatted(FormattingOptions{Causes: true})
+	return err.Formatted(FormattingOptions{Causes: false})
+}
+
+func Is(e1, e2 error) bool {
+	if e1WithIs, ok := e1.(ErrorWithIs); ok {
+		return e1WithIs.Is(e2)
+	}
+	return e1.Error() == e2.Error()
+}
+
+//Is() compares the message string of this or any cause to match the specified error message
+func (err CustomError) Is(specificError error) bool {
+	if err.message == specificError.Error() {
+		return true
+	}
+	if err.cause != nil {
+		if causeWithIs, ok := err.cause.(ErrorWithIs); ok {
+			return causeWithIs.Is(specificError)
+		}
+	}
+	return false
 }
 
 //implement github.com/pkg/errors: Cause
diff --git a/errors/errors.go b/errors/errors.go
index f81e17b..8feac4f 100644
--- a/errors/errors.go
+++ b/errors/errors.go
@@ -13,6 +13,11 @@ type ErrorWithCause interface {
 	Code() int
 }
 
+type ErrorWithIs interface {
+	error
+	Is(specificError error) bool
+}
+
 func New(message string) error {
 	err := &CustomError{
 		message: message,
diff --git a/errors/errors_test.go b/errors/errors_test.go
index d2ce899..f9a8e2c 100644
--- a/errors/errors_test.go
+++ b/errors/errors_test.go
@@ -2,6 +2,7 @@ package errors_test
 
 import (
 	"encoding/json"
+	"fmt"
 	"testing"
 
 	"gitlab.com/uafrica/go-utils/errors"
@@ -90,3 +91,42 @@ func f(i int) error {
 	}
 	return errors.Errorf("i=%d is odd", i)
 }
+
+func TestIs(t *testing.T) {
+	//in some condition program returns
+
+	for n := 0; n <= 4; n++ {
+		if err := ReadDb(n); err != nil {
+			switch {
+			case errors.Is(err, errNotFound):
+				t.Logf("n=%d failed with NOT FOUND", n)
+			case errors.Is(err, errDisabled):
+				t.Logf("n=%d failed because disabled", n)
+			default:
+				t.Logf("n=%d failed for unknown cause: %v", n, err)
+			}
+		} else {
+			t.Logf("n=%d worked", n)
+		}
+	}
+}
+
+var (
+	errNotFound = errors.New("Not Found")
+	errDisabled = errors.New("Disabled")
+)
+
+func ReadDb(x int) error {
+	switch x {
+	case 1:
+		return errNotFound
+	case 2:
+		return errDisabled
+	case 3:
+		return nil
+	case 4:
+		return fmt.Errorf("some silly bug")
+	default:
+		return errors.Errorf("invalid x=%d", x)
+	}
+}
-- 
GitLab