Skip to content

Commit 269c4c3

Browse files
authored
Merge pull request #26 from srvc/izumin5210/slice
Check error comparability for avoiding panic
2 parents c6bd1f3 + 4dc2444 commit 269c4c3

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

pkgerrors.go

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package fail
22

33
import (
4+
"reflect"
45
"strings"
56

67
pkgerrors "github.com/pkg/errors"
@@ -79,8 +80,15 @@ func extractPkgError(err error) *pkgError {
7980
break
8081
}
8182

82-
if len(stackTraces) == 0 && rootErr == err {
83-
return nil
83+
if len(stackTraces) == 0 {
84+
ret, et := reflect.TypeOf(rootErr), reflect.TypeOf(err)
85+
if ret != nil && et != nil && ret.Comparable() && et.Comparable() {
86+
if rootErr == err {
87+
return nil
88+
}
89+
} else {
90+
return nil
91+
}
8492
}
8593

8694
// Extract annotated messages by removing the trailing message.

pkgerrors_test.go

+28
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package fail
22

33
import (
44
"errors"
5+
"strings"
56
"testing"
67

78
pkgerrors "github.com/pkg/errors"
@@ -20,6 +21,13 @@ func TestExtractPkgError(t *testing.T) {
2021
assert.Nil(t, pkgErr)
2122
})
2223

24+
t.Run("slice error", func(t *testing.T) {
25+
err := errorSlice{errors.New("error")}
26+
27+
pkgErr := extractPkgError(err)
28+
assert.Nil(t, pkgErr)
29+
})
30+
2331
t.Run("pkg/errors.New", func(t *testing.T) {
2432
err := pkgErrorsNew("message")
2533

@@ -84,6 +92,16 @@ func TestExtractPkgError(t *testing.T) {
8492
assert.NotEmpty(t, pkgErr.StackTrace)
8593
assert.Equal(t, "pkgErrorsWrap", pkgErr.StackTrace[0].Func)
8694
})
95+
96+
t.Run("with slice error", func(t *testing.T) {
97+
err0 := errorSlice{errors.New("error")}
98+
err1 := pkgErrorsWrap(err0, "message")
99+
100+
pkgErr := extractPkgError(err1)
101+
assert.NotNil(t, pkgErr)
102+
assert.Equal(t, err0, pkgErr.Err)
103+
assert.NotEmpty(t, pkgErr.StackTrace)
104+
})
87105
})
88106

89107
t.Run("pkg/errors.WithMessage", func(t *testing.T) {
@@ -156,3 +174,13 @@ func pkgErrorsNew(msg string) error {
156174
func pkgErrorsWrap(err error, msg string) error {
157175
return pkgerrors.Wrap(err, msg)
158176
}
177+
178+
type errorSlice []error
179+
180+
func (s errorSlice) Error() string {
181+
msg := make([]string, len(s))
182+
for i, e := range s {
183+
msg[i] = e.Error()
184+
}
185+
return strings.Join(msg, ": ")
186+
}

0 commit comments

Comments
 (0)