diff --git a/each_test.go b/each_test.go index 596417d..4199afd 100644 --- a/each_test.go +++ b/each_test.go @@ -43,13 +43,13 @@ func TestEach(t *testing.T) { func TestEachWithContext(t *testing.T) { rule := Each(WithContext(func(ctx context.Context, value interface{}) error { - if !strings.Contains(value.(string), ctx.Value("contains").(string)) { + if !strings.Contains(value.(string), ctx.Value(contains).(string)) { return errors.New("unexpected value") } return nil })) - ctx1 := context.WithValue(context.Background(), "contains", "abc") - ctx2 := context.WithValue(context.Background(), "contains", "xyz") + ctx1 := context.WithValue(context.Background(), contains, "abc") + ctx2 := context.WithValue(context.Background(), contains, "xyz") tests := []struct { tag string diff --git a/when.go b/when.go index 7bcdff5..2c785a0 100644 --- a/when.go +++ b/when.go @@ -28,16 +28,14 @@ func (r WhenRule) ValidateWithContext(ctx context.Context, value interface{}) er if r.condition { if ctx == nil { return Validate(value, r.rules...) - } else { - return ValidateWithContext(ctx, value, r.rules...) } + return ValidateWithContext(ctx, value, r.rules...) } if ctx == nil { return Validate(value, r.elseRules...) - } else { - return ValidateWithContext(ctx, value, r.elseRules...) } + return ValidateWithContext(ctx, value, r.elseRules...) } // Else returns a validation rule that executes the given list of rules when the condition is false. diff --git a/when_test.go b/when_test.go index 589923d..51c7ec4 100644 --- a/when_test.go +++ b/when_test.go @@ -57,15 +57,21 @@ func TestWhen(t *testing.T) { } } +type ctxKey int + +const ( + contains ctxKey = iota +) + func TestWhenWithContext(t *testing.T) { rule := WithContext(func(ctx context.Context, value interface{}) error { - if !strings.Contains(value.(string), ctx.Value("contains").(string)) { + if !strings.Contains(value.(string), ctx.Value(contains).(string)) { return errors.New("unexpected value") } return nil }) - ctx1 := context.WithValue(context.Background(), "contains", "abc") - ctx2 := context.WithValue(context.Background(), "contains", "xyz") + ctx1 := context.WithValue(context.Background(), contains, "abc") + ctx2 := context.WithValue(context.Background(), contains, "xyz") tests := []struct { tag string