Skip to content

Commit 6e46573

Browse files
committed
Fix forwarding rpc status and improve api consistency
1 parent ab56f54 commit 6e46573

File tree

4 files changed

+129
-70
lines changed

4 files changed

+129
-70
lines changed

clientconn.go

+21-10
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,20 @@ func (cc *ClientConn) invoke(
5050
callOpts ...grpc.CallOption,
5151
) error {
5252
var serviceName, methodName string
53-
if method != Forward {
53+
forwarding := method == Forward
54+
if !forwarding {
5455
var err error
5556
serviceName, methodName, err = parseQualifiedMethod(method)
5657
if err != nil {
5758
return err
5859
}
60+
} else {
61+
if _, ok := req.(*RPC); !ok {
62+
return fmt.Errorf("[totem] invalid forwarding request type: %T (expected *totem.RPC)", req)
63+
}
64+
if _, ok := reply.(*RPC); !ok {
65+
return fmt.Errorf("[totem] invalid forwarding response type: %T (expected *totem.RPC)", reply)
66+
}
5967
}
6068

6169
md, ok := metadata.FromOutgoingContext(ctx)
@@ -127,18 +135,20 @@ func (cc *ClientConn) invoke(
127135
select {
128136
case rpc := <-future:
129137
resp := rpc.GetResponse()
130-
stat := resp.GetStatus()
131-
if err := stat.Err(); err != nil {
132-
cc.logger.Debug("received reply with error", "tag", rpc.Tag,
133-
"method", method,
134-
"error", err)
138+
if !forwarding {
139+
// don't unpack response status from replies to forwarded messages
140+
stat := resp.GetStatus()
141+
if err := stat.Err(); err != nil {
142+
cc.logger.Debug("received reply with error", "tag", rpc.Tag,
143+
"method", method,
144+
"error", err)
135145

136-
recordErrorStatus(span, stat)
137-
return err
146+
recordErrorStatus(span, stat)
147+
return err
148+
}
138149
}
139150

140-
cc.logger.Debug("received reply", "tag", rpc.Tag,
141-
"method", method)
151+
cc.logger.Debug("received reply", "tag", rpc.Tag, "method", method)
142152

143153
recordSuccess(span)
144154
cc.metrics.TrackSvcTxLatency(serviceName, methodName, time.Since(startTime))
@@ -158,6 +168,7 @@ func (cc *ClientConn) invoke(
158168
reply.Content = &RPC_Response{
159169
Response: resp,
160170
}
171+
reply.Metadata = FromMD(rpc.Metadata.ToMD())
161172
case protoadapt.MessageV2:
162173
if err := proto.Unmarshal(resp.GetResponse(), reply); err != nil {
163174
cc.logger.Error("received malformed response message", "tag", rpc.Tag,

stream.go

-4
Original file line numberDiff line numberDiff line change
@@ -591,10 +591,6 @@ func (sh *StreamController) handleRequest(ctx context.Context, msg *RPC, md meta
591591
info := first.MethodInfo[method]
592592
switch {
593593
case !info.IsServerStream && !info.IsClientStream:
594-
// very important to copy the message here, otherwise the tag
595-
// will be overwritten, and we need to preserve it to reply to
596-
// the original request
597-
// todo: does the above still apply?
598594
response, err := invoker.Invoke(ctx, msg)
599595
if err != nil {
600596
recordError(span, err)

totem_suite_test.go

+56-3
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,24 @@ import (
1212
"testing"
1313
"time"
1414

15+
"github.com/google/go-cmp/cmp"
1516
. "github.com/onsi/ginkgo/v2"
1617
. "github.com/onsi/gomega"
18+
"github.com/onsi/gomega/format"
1719
"go.opentelemetry.io/contrib/propagators/autoprop"
1820
"go.opentelemetry.io/otel"
1921
"go.opentelemetry.io/otel/exporters/jaeger"
2022
"go.opentelemetry.io/otel/sdk/resource"
2123
tracesdk "go.opentelemetry.io/otel/sdk/trace"
2224
semconv "go.opentelemetry.io/otel/semconv/v1.12.0"
2325

26+
"google.golang.org/genproto/googleapis/rpc/errdetails"
2427
"google.golang.org/grpc"
2528
"google.golang.org/grpc/codes"
2629
"google.golang.org/grpc/status"
30+
"google.golang.org/protobuf/encoding/prototext"
31+
"google.golang.org/protobuf/proto"
32+
"google.golang.org/protobuf/testing/protocmp"
2733
"google.golang.org/protobuf/types/known/emptypb"
2834

2935
. "github.com/kralicky/totem/test"
@@ -86,6 +92,7 @@ type testServer struct {
8692
}
8793

8894
func (ts *testServer) TestStream(stream Test_TestStreamServer) error {
95+
defer GinkgoRecover()
8996
defer ts.wg.Done()
9097
return ts.testCase.ServerHandler(stream)
9198
}
@@ -237,10 +244,16 @@ type errorServer struct {
237244
requestLimiter
238245
}
239246

240-
func (s *errorServer) Error(ctx context.Context, err *ErrorRequest) (*emptypb.Empty, error) {
247+
func (s *errorServer) Error(ctx context.Context, req *ErrorRequest) (*emptypb.Empty, error) {
241248
defer s.requestLimiter.Tick()
242-
if err.ReturnError {
243-
return nil, status.Error(codes.Aborted, "error")
249+
err, _ := status.New(codes.Aborted, "error").WithDetails(&errdetails.ErrorInfo{Reason: "reason", Domain: "domain", Metadata: map[string]string{"key": "value"}})
250+
251+
// if err := grpc.SetHeader(ctx, metadata.Pairs("errorKey", "errorValue")); err != nil {
252+
// panic(err)
253+
// }
254+
255+
if req.ReturnError {
256+
return nil, err.Err()
244257
} else {
245258
return &emptypb.Empty{}, nil
246259
}
@@ -308,3 +321,43 @@ func (s *countServer) Count(in *Number, stream Count_CountServer) error {
308321
}
309322
return nil
310323
}
324+
325+
func ProtoEqual(expected proto.Message) *ProtoMatcher {
326+
return &ProtoMatcher{
327+
Expected: expected,
328+
}
329+
}
330+
331+
type ProtoMatcher struct {
332+
Expected proto.Message
333+
}
334+
335+
func (matcher *ProtoMatcher) Match(actual any) (success bool, err error) {
336+
if actual == nil && matcher.Expected == nil {
337+
return false, fmt.Errorf("Refusing to compare <nil> to <nil>.\nBe explicit and use BeNil() instead. This is to avoid mistakes where both sides of an assertion are erroneously uninitialized")
338+
}
339+
if _, ok := actual.(proto.Message); !ok {
340+
return false, fmt.Errorf("ProtoMatcher expects a proto.Message. Got:\n%s", format.Object(actual, 1))
341+
}
342+
return proto.Equal(actual.(proto.Message), matcher.Expected), nil
343+
}
344+
345+
func (matcher *ProtoMatcher) FailureMessage(actual any) (message string) {
346+
diff := cmp.Diff(actual.(proto.Message), matcher.Expected, protocmp.Transform())
347+
return fmt.Sprintf("Expected\n%s\n%s\n%s\ndiff:\n%s",
348+
format.IndentString(prototext.Format(actual.(proto.Message)), 1),
349+
"to equal",
350+
format.IndentString(prototext.Format(matcher.Expected), 1),
351+
diff,
352+
)
353+
}
354+
355+
func (matcher *ProtoMatcher) NegatedFailureMessage(actual any) (message string) {
356+
diff := cmp.Diff(actual.(proto.Message), matcher.Expected, protocmp.Transform())
357+
return fmt.Sprintf("Expected\n%s\n%s\n%s\ndiff:\n%s",
358+
format.IndentString(prototext.Format(actual.(proto.Message)), 1),
359+
"not to equal",
360+
format.IndentString(prototext.Format(matcher.Expected), 1),
361+
diff,
362+
)
363+
}

totem_test.go

+52-53
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,35 @@ import (
99
sync "sync"
1010
"time"
1111

12+
"github.com/google/go-cmp/cmp"
1213
"github.com/kralicky/totem"
1314
"github.com/kralicky/totem/test"
1415
. "github.com/onsi/ginkgo/v2"
1516
. "github.com/onsi/gomega"
17+
"github.com/onsi/gomega/format"
18+
"google.golang.org/genproto/googleapis/rpc/errdetails"
19+
statuspb "google.golang.org/genproto/googleapis/rpc/status"
1620
"google.golang.org/grpc"
1721
"google.golang.org/grpc/codes"
1822
"google.golang.org/grpc/metadata"
1923
"google.golang.org/grpc/status"
24+
"google.golang.org/protobuf/encoding/prototext"
2025
"google.golang.org/protobuf/proto"
26+
"google.golang.org/protobuf/testing/protocmp"
27+
"google.golang.org/protobuf/types/known/anypb"
2128
"google.golang.org/protobuf/types/known/durationpb"
2229
)
2330

2431
var (
2532
timeout = time.Second * 6
2633
)
2734

28-
var _ = Describe("Test", func() {
35+
var _ = FDescribe("Test", func() {
2936
It("should work with two different servers", func() {
3037
a, b := make(chan struct{}), make(chan struct{})
3138
tc := testCase{
3239
ServerHandler: func(stream test.Test_TestStreamServer) error {
40+
GinkgoHelper()
3341
ts, err := totem.NewServer(stream)
3442
if err != nil {
3543
return err
@@ -184,7 +192,9 @@ var _ = Describe("Test", func() {
184192
return err
185193
}
186194
incSrv := incrementServer{}
195+
errSrv := errorServer{}
187196
test.RegisterIncrementServer(ts, &incSrv)
197+
test.RegisterErrorServer(ts, &errSrv)
188198
_, errC := ts.Serve()
189199

190200
return <-errC
@@ -212,72 +222,57 @@ var _ = Describe("Test", func() {
212222
ctx, ca := context.WithTimeout(context.Background(), timeout)
213223
defer ca()
214224
err = cc.Invoke(ctx, totem.Forward, req, reply)
215-
if err != nil {
216-
return err
217-
}
225+
Expect(err).NotTo(HaveOccurred())
218226

219227
respValue := &test.Number{}
220228
err = proto.Unmarshal(reply.GetResponse().GetResponse(), respValue)
221229

222230
if err != nil {
223231
return err
224232
}
225-
if respValue.GetValue() != 1235 {
226-
return fmt.Errorf("expected 1235, got %d", respValue.GetValue())
227-
}
233+
Expect(respValue.GetValue()).To(Equal(int64(1235)))
228234

229-
close(done)
230-
return <-errC
231-
},
232-
}
233-
tc.Run(done)
234-
})
235-
236-
It("should forward raw RPCs and receive regular proto messages", func() {
237-
done := make(chan struct{})
238-
tc := testCase{
239-
ServerHandler: func(stream test.Test_TestStreamServer) error {
240-
ts, err := totem.NewServer(stream)
241-
if err != nil {
242-
return err
243-
}
244-
incSrv := incrementServer{}
245-
test.RegisterIncrementServer(ts, &incSrv)
246-
_, errC := ts.Serve()
247-
248-
return <-errC
249-
},
250-
ClientHandler: func(stream test.Test_TestStreamClient) error {
251-
ts, err := totem.NewServer(stream)
252-
if err != nil {
253-
return err
254-
}
255-
256-
cc, errC := ts.Serve()
257-
258-
reqBytes, _ := proto.Marshal(&test.Number{
259-
Value: 1234,
260-
})
261-
req := &totem.RPC{
262-
ServiceName: "test.Increment",
263-
MethodName: "Inc",
235+
errReq := &test.ErrorRequest{ReturnError: true}
236+
errReqBytes, _ := proto.Marshal(errReq)
237+
req = &totem.RPC{
238+
ServiceName: "test.Error",
239+
MethodName: "Error",
264240
Content: &totem.RPC_Request{
265-
Request: reqBytes,
241+
Request: errReqBytes,
266242
},
267243
}
268-
reply := &test.Number{}
269-
270-
ctx, ca := context.WithTimeout(context.Background(), timeout)
271-
defer ca()
244+
reply = &totem.RPC{}
272245
err = cc.Invoke(ctx, totem.Forward, req, reply)
273-
if err != nil {
274-
return err
246+
Expect(err).To(BeNil())
247+
errinfo, _ := anypb.New(&errdetails.ErrorInfo{
248+
Reason: "reason",
249+
Domain: "domain",
250+
Metadata: map[string]string{"key": "value"},
251+
})
252+
expected := &totem.RPC{
253+
Content: &totem.RPC_Response{
254+
Response: &totem.Response{
255+
Response: nil,
256+
StatusProto: &statuspb.Status{
257+
Code: int32(codes.Aborted),
258+
Message: "error",
259+
Details: []*anypb.Any{
260+
errinfo,
261+
},
262+
},
263+
},
264+
},
265+
// Metadata: totem.FromMD(metadata.Pairs("errorKey", "errorValue")),
275266
}
276-
277-
if reply.GetValue() != 1235 {
278-
return fmt.Errorf("expected 1235, got %d", reply.GetValue())
267+
if !proto.Equal(reply, expected) {
268+
diff := cmp.Diff(reply, expected, protocmp.Transform())
269+
return fmt.Errorf("Expected\n%s\n%s\n%s\ndiff:\n%s",
270+
format.IndentString(prototext.Format(reply), 1),
271+
"to equal",
272+
format.IndentString(prototext.Format(expected), 1),
273+
diff,
274+
)
279275
}
280-
281276
close(done)
282277
return <-errC
283278
},
@@ -1093,6 +1088,7 @@ var _ = Describe("Test", func() {
10931088
})
10941089

10951090
func checkIncrement(cc grpc.ClientConnInterface) {
1091+
GinkgoHelper()
10961092
incClient := test.NewIncrementClient(cc)
10971093
ctx := metadata.AppendToOutgoingContext(context.Background(), "test", "increment")
10981094
result, err := incClient.Inc(ctx, &test.Number{
@@ -1103,6 +1099,7 @@ func checkIncrement(cc grpc.ClientConnInterface) {
11031099
}
11041100

11051101
func checkDecrement(cc grpc.ClientConnInterface) {
1102+
GinkgoHelper()
11061103
decClient := test.NewDecrementClient(cc)
11071104
ctx := metadata.AppendToOutgoingContext(context.Background(), "test", "decrement")
11081105
result, err := decClient.Dec(ctx, &test.Number{
@@ -1113,6 +1110,7 @@ func checkDecrement(cc grpc.ClientConnInterface) {
11131110
}
11141111

11151112
func checkMultiply(cc grpc.ClientConnInterface) {
1113+
GinkgoHelper()
11161114
mulClient := test.NewMultiplyClient(cc)
11171115
ctx := metadata.AppendToOutgoingContext(context.Background(), "test", "multiply")
11181116
result, err := mulClient.Mul(ctx, &test.Operands{
@@ -1124,6 +1122,7 @@ func checkMultiply(cc grpc.ClientConnInterface) {
11241122
}
11251123

11261124
func checkHash(cc grpc.ClientConnInterface) {
1125+
GinkgoHelper()
11271126
hashClient := test.NewHashClient(cc)
11281127
ctx := metadata.AppendToOutgoingContext(context.Background(), "test", "hash")
11291128
result, err := hashClient.Hash(ctx, &test.String{

0 commit comments

Comments
 (0)