Skip to content

Commit 29217c2

Browse files
committed
Add the ability to check if an error stack contains an error value.
1 parent cae37a5 commit 29217c2

File tree

2 files changed

+37
-8
lines changed

2 files changed

+37
-8
lines changed

utils.go

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,22 +104,34 @@ func FieldsSlice(err error) []interface{} {
104104
return fields
105105
}
106106

107-
// Is reports whether err is an *Error of the given Kind(s). If the given error is an *Error, but
108-
// it has an empty kind, it will check the cause of that error recursively. If error is not an
109-
// *Error, or is nil, Is will return false.
110-
func Is(err error, kind ...Kind) bool {
107+
// Is reports whether the err is an *Error of the given kind/value. If the given kind is of type Kind/string, it will be
108+
// checked against the error's Kind. If the given kind is of any other type, it will be checked against the error's
109+
// cause. This is done recursively until a matching error is found. Calling Is with multiple kinds reports whether the
110+
// error is one of the given kind/values, not all of.
111+
func Is(err error, kind ...interface{}) bool {
111112
if err == nil {
112113
return false
113114
}
114115

115116
e, ok := err.(*Error)
116-
if ok && e.Kind != "" {
117-
for _, k := range kind {
118-
if e.Kind == k {
117+
if !ok {
118+
return false
119+
}
120+
121+
for _, k := range kind {
122+
switch val := k.(type) {
123+
case Kind, string:
124+
if e.Kind == val {
125+
return true
126+
}
127+
default:
128+
if e.Cause == val {
119129
return true
120130
}
121131
}
122-
} else if ok && e.Cause != nil {
132+
}
133+
134+
if e.Cause != nil {
123135
return Is(e.Cause, kind...)
124136
}
125137

utils_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package errors
22

33
import (
4+
"context"
45
"errors"
56
"strings"
67
"testing"
@@ -243,9 +244,17 @@ func TestFieldsSlice(t *testing.T) {
243244
})
244245
}
245246

247+
type errorType struct {}
248+
249+
func (t errorType) Error() string {
250+
return "error!"
251+
}
252+
253+
246254
func TestIs(t *testing.T) {
247255
kind1 := Kind("testing 1")
248256
kind2 := Kind("testing 2")
257+
kind3 := errorType{}
249258

250259
t.Run("should return false on nil error", func(t *testing.T) {
251260
assert.False(t, Is(nil))
@@ -280,6 +289,14 @@ func TestIs(t *testing.T) {
280289
err := Wrap(New(kind1))
281290
assert.True(t, Is(err, kind2, kind1))
282291
})
292+
293+
t.Run("should also be able to check value types", func(t *testing.T) {
294+
err := Wrap(kind3)
295+
assert.True(t, Is(err, kind3))
296+
297+
err = Wrap(context.Canceled)
298+
assert.True(t, Is(err, context.Canceled))
299+
})
283300
}
284301

285302
func TestMessage(t *testing.T) {

0 commit comments

Comments
 (0)