From 2f6c30fde4a278d4f748cd7fc4c44a385a4db54e Mon Sep 17 00:00:00 2001 From: Shaik Zakir Hussain Date: Mon, 9 Jun 2025 11:32:54 +0530 Subject: [PATCH 1/2] drpcserver: add server interceptor support in drpc Add support for server interceptors in drpc. Interceptors are taken as server options, while instantiating the server. All interceptors added via server options are chained into a single server interceptor. Unlike grpc, here there is no separation between unary and stream interceptors because the unerdlying logic of handling the rpc calls is the same i.e. `drpc.Handler`'s `handleRPC` implementation. In grpc, we have separate processing for unary rpc's and stream rpc's. This common processing for both stream and unary rpcs has helped condense the interceptor interfaces into a single interface that works for both use-cases, essentially simplifying the logic of initializing and wiring of interceptors. Fixes: #147622 Epic: CRDB-51168 --- .gitignore | 3 +- drpcserver/server.go | 82 +++++++- drpcserver/server_interceptor.go | 24 +++ drpcserver/server_interceptor_test.go | 263 ++++++++++++++++++++++++++ 4 files changed, 365 insertions(+), 7 deletions(-) create mode 100644 drpcserver/server_interceptor.go create mode 100644 drpcserver/server_interceptor_test.go diff --git a/.gitignore b/.gitignore index b5c62eb..83272f1 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ .vscode vendor result -.idea/* \ No newline at end of file +.idea/* +.DS_Store \ No newline at end of file diff --git a/drpcserver/server.go b/drpcserver/server.go index 71fc457..87e1d53 100644 --- a/drpcserver/server.go +++ b/drpcserver/server.go @@ -33,6 +33,66 @@ type Options struct { // CollectStats controls whether the server should collect stats on the // rpcs it serves. CollectStats bool + + serverInt ServerInterceptor + serverInts []ServerInterceptor +} + +// A ServerOption sets options such as server interceptors. +type ServerOption func(options *Options) + +// WithChainServerInterceptor creates a ServerOption that chains multiple server interceptors, +// with the first being the outermost wrapper and the last being the innermost. +func WithChainServerInterceptor(ints ...ServerInterceptor) ServerOption { + return func(opt *Options) { + opt.serverInts = append(opt.serverInts, ints...) + } +} + +// chainServerInterceptors chains all server interceptors in the Options into a single interceptor. +// The combined chained interceptor is stored in opts.serverInt. The interceptors are invoked in the order they were added. +// +// Example usage: +// +// Interceptors are typically added using WithChainServerInterceptor when creating the server. +// The NewWithOptions function calls chainServerInterceptors internally to process these. +// server := drpcserver.NewWithOptions( +// drpcHandler, +// drpcserver.Options{}, // base server options +// drpcserver.WithChainServerInterceptor(loggingInterceptor, metricsInterceptor), +// ) +// +// // Chain the interceptors +// chainServerInterceptors(server) +// // server.opts.serverInt now contains the chained server interceptors. +func chainServerInterceptors(s *Server) { + switch n := len(s.opts.serverInts); n { + case 0: + s.opts.serverInt = nil + case 1: + s.opts.serverInt = s.opts.serverInts[0] + default: + s.opts.serverInt = func( + ctx context.Context, + rpc string, + stream drpc.Stream, + handler drpc.Handler, + ) error { + chained := handler + for i := n - 1; i >= 0; i-- { + next := chained + interceptor := s.opts.serverInts[i] + chainedFn := func( + stream drpc.Stream, + rpc string, + ) error { + return interceptor(ctx, rpc, stream, next) + } + chained = HandlerFunc(chainedFn) + } + return chained.HandleRPC(stream, rpc) + } + } } // Server is an implementation of drpc.Server to serve drpc connections. @@ -51,7 +111,7 @@ func New(handler drpc.Handler) *Server { // NewWithOptions constructs a new Server using the provided options to tune // how the drpc connections are handled. -func NewWithOptions(handler drpc.Handler, opts Options) *Server { +func NewWithOptions(handler drpc.Handler, opts Options, sopts ...ServerOption) *Server { s := &Server{ opts: opts, handler: handler, @@ -61,6 +121,10 @@ func NewWithOptions(handler drpc.Handler, opts Options) *Server { drpcopts.SetManagerStatsCB(&s.opts.Manager.Internal, s.getStats) s.stats = make(map[string]*drpcstats.Stats) } + for _, opt := range sopts { + opt(&s.opts) + } + chainServerInterceptors(s) return s } @@ -105,7 +169,7 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { if err != nil { return errs.Wrap(err) } - if err := s.handleRPC(stream, rpc); err != nil { + if err := s.handleRPC(ctx, stream, rpc); err != nil { return errs.Wrap(err) } } @@ -162,10 +226,16 @@ func (s *Server) Serve(ctx context.Context, lis net.Listener) (err error) { } // handleRPC handles the rpc that has been requested by the stream. -func (s *Server) handleRPC(stream *drpcstream.Stream, rpc string) (err error) { - err = s.handler.HandleRPC(stream, rpc) - if err != nil { - return errs.Wrap(stream.SendError(err)) +func (s *Server) handleRPC(ctx context.Context, stream *drpcstream.Stream, rpc string) (err error) { + var processingErr error + if s.opts.serverInt != nil { + processingErr = s.opts.serverInt(ctx, rpc, stream, s.handler) + } else { + processingErr = s.handler.HandleRPC(stream, rpc) + } + + if processingErr != nil { + return errs.Wrap(stream.SendError(processingErr)) } return errs.Wrap(stream.CloseSend()) } diff --git a/drpcserver/server_interceptor.go b/drpcserver/server_interceptor.go new file mode 100644 index 0000000..6b38228 --- /dev/null +++ b/drpcserver/server_interceptor.go @@ -0,0 +1,24 @@ +package drpcserver + +import ( + "context" + "storj.io/drpc" +) + +// HandlerFunc is an adapter to allow the use of ordinary functions as drpc.Handlers. +// If f is a function with the appropriate signature, HandlerFunc(f) is a +// drpc.Handler object that calls f. +type HandlerFunc func(stream drpc.Stream, rpc string) error + +// HandleRPC calls f(stream, rpc). +// It implements the drpc.Handler interface. +func (f HandlerFunc) HandleRPC(stream drpc.Stream, rpc string) error { + return f(stream, rpc) +} + +// ServerInterceptor is a function that intercepts the execution of a DRPC method on the server. +// It allows for cross-cutting concerns like logging, metrics, authentication, or request manipulation +// to be applied to RPCs. +// It is the responsibility of the interceptor to call handler.HandleRPC to continue +// processing the RPC, or to terminate the RPC by returning an error or handling it directly. +type ServerInterceptor func(ctx context.Context, rpc string, stream drpc.Stream, handler drpc.Handler) error diff --git a/drpcserver/server_interceptor_test.go b/drpcserver/server_interceptor_test.go new file mode 100644 index 0000000..eca7268 --- /dev/null +++ b/drpcserver/server_interceptor_test.go @@ -0,0 +1,263 @@ +package drpcserver + +import ( + "context" + "errors" + "strings" + "testing" + + "storj.io/drpc" +) + +// emptyMessage is a drpc.Message that carries no data. +type emptyMessage struct{} + +// DRPCMessage implements drpc.Message to satisfy the interface. +func (emptyMessage) DRPCMessage() {} + +// dummyEncoding is a no-op drpc.Encoding for testing purposes. +type dummyEncoding struct{} + +// Marshal implements drpc.Encoding. +func (d dummyEncoding) Marshal(msg drpc.Message) ([]byte, error) { + if msg == nil { + return nil, nil + } + // For an emptyMessage, we can return an empty byte slice. + if _, ok := msg.(*emptyMessage); ok { + return []byte{}, nil + } + return nil, errors.New("dummyEncoding can only marshal *emptyMessage") +} + +// Unmarshal implements drpc.Encoding. +func (d dummyEncoding) Unmarshal(data []byte, msg drpc.Message) error { + return nil +} + +// mockHandler is a mock drpc.Handler. +type mockHandler struct { + fn func(stream drpc.Stream, rpc string) error +} + +func (m *mockHandler) HandleRPC(stream drpc.Stream, rpc string) error { + if m.fn != nil { + return m.fn(stream, rpc) + } + return nil +} + +// mockStream is a mock drpc.Stream. +type mockStream struct { + ctx context.Context + + // Fields to track behavior for tests + msgSent []drpc.Message + msgRecvd []drpc.Message + closeSendCalled bool + closedCalled bool + sendErrorCalled bool + lastErrorSent error +} + +func (m *mockStream) Context() context.Context { + if m.ctx != nil { + return m.ctx + } + return context.Background() +} + +func (m *mockStream) MsgSend(msg drpc.Message, _ drpc.Encoding) error { + m.msgSent = append(m.msgSent, msg) + return nil +} + +func (m *mockStream) MsgRecv(_ drpc.Message, _ drpc.Encoding) error { + if len(m.msgRecvd) > 0 { + return nil + } + return errors.New("mockStream: no messages to receive") +} + +func (m *mockStream) CloseSend() error { + m.closeSendCalled = true + return nil +} + +func (m *mockStream) Close() error { + m.closedCalled = true + return nil +} + +func (m *mockStream) SendError(err error) error { + m.sendErrorCalled = true + m.lastErrorSent = err + return nil +} + +func TestWithChainServerInterceptor(t *testing.T) { + interceptor1 := func(ctx context.Context, rpc string, stream drpc.Stream, handler drpc.Handler) error { + return handler.HandleRPC(stream, rpc) + } + interceptor2 := func(ctx context.Context, rpc string, stream drpc.Stream, handler drpc.Handler) error { + return handler.HandleRPC(stream, rpc) + } + + opt := WithChainServerInterceptor(interceptor1, interceptor2) + opts := &Options{} + opt(opts) + + if opts.serverInts == nil || len(opts.serverInts) != 2 { + t.Fatal("serverInts should not be nil") + } +} + +func TestNewWithOptions_WithInterceptors(t *testing.T) { + interceptor := func(ctx context.Context, rpc string, stream drpc.Stream, handler drpc.Handler) error { + return handler.HandleRPC(stream, rpc) + } + // Create a mock handler instance + mockRPCHandler := &mockHandler{ + fn: func(stream drpc.Stream, rpc string) error { + return nil + }, + } + srv := NewWithOptions(mockRPCHandler, Options{}, WithChainServerInterceptor(interceptor)) + + if srv.opts.serverInt == nil { + t.Fatal("serverInt should not be nil in server options") + } +} + +func TestServer_handleRPC_InterceptorError(t *testing.T) { + expectedErr := errors.New("interceptor error") + + interceptor1 := func(ctx context.Context, rpc string, stream drpc.Stream, handler drpc.Handler) error { + return expectedErr // Error out before calling next + } + + interceptor2 := func(ctx context.Context, rpc string, stream drpc.Stream, handler drpc.Handler) error { + t.Error("interceptor2 should not be called") + return handler.HandleRPC(stream, rpc) + } + + handlerCalled := false + mockRPCHandler := &mockHandler{ + fn: func(stream drpc.Stream, rpc string) error { + handlerCalled = true + return nil + }, + } + + srv := NewWithOptions(mockRPCHandler, Options{}, WithChainServerInterceptor(interceptor1, interceptor2)) + finalInterceptor := srv.opts.serverInt + if finalInterceptor == nil { + t.Fatal("serverInt is nil after NewWithOptions") + } + + err := finalInterceptor(context.Background(), "TestRPC", &mockStream{}, mockRPCHandler) + if err == nil { + t.Fatal("expected an error from interceptor chain, got nil") + } + if !strings.Contains(err.Error(), expectedErr.Error()) { + t.Errorf("expected error '%v', got '%v'", expectedErr, err) + } + + if handlerCalled { + t.Error("handler should not have been called when an interceptor errors") + } +} + +type contextKey string + +const testCtxKey = contextKey("testKey") + +func TestServer_handleRPC_InterceptorContextPropagation(t *testing.T) { + var ( + valFromCtx1 interface{} + valFromCtx2 interface{} + valFromCtxHandler interface{} + ) + + interceptor1 := func(ctx context.Context, rpc string, stream drpc.Stream, handler drpc.Handler) error { + newCtx := context.WithValue(ctx, testCtxKey, "value1") + return handler.HandleRPC(&mockStream{ctx: newCtx}, rpc) // Pass modified context via stream + } + + interceptor2 := func(ctx context.Context, rpc string, stream drpc.Stream, handler drpc.Handler) error { + valFromCtx1 = stream.Context().Value(testCtxKey) + newCtx := context.WithValue(stream.Context(), testCtxKey, "value2") + return handler.HandleRPC(&mockStream{ctx: newCtx}, rpc) // Pass modified context via stream + } + + mockRPCHandler := &mockHandler{ + fn: func(stream drpc.Stream, rpc string) error { + valFromCtx2 = stream.Context().Value(testCtxKey) // This should be "value2" + // Simulate handler using the context from the stream + valFromCtxHandler = stream.Context().Value(testCtxKey) + return nil + }, + } + + srv := NewWithOptions(mockRPCHandler, Options{}, WithChainServerInterceptor(interceptor1, interceptor2)) + finalInterceptor := srv.opts.serverInt + if finalInterceptor == nil { + t.Fatal("serverInt is nil after NewWithOptions") + } + + // We pass the initial context via the mockStream passed to the finalInterceptor + initialCtx := context.Background() + err := finalInterceptor(initialCtx, "TestRPC", &mockStream{ctx: initialCtx}, mockRPCHandler) + if err != nil { + t.Fatalf("interceptor chain returned an error: %v", err) + } + + if valFromCtx1 != "value1" { + t.Errorf("expected value 'value1' from context in interceptor2, got '%v'", valFromCtx1) + } + if valFromCtx2 != "value2" { + t.Errorf("expected value 'value2' from context in handler (set by interceptor2), got '%v'", valFromCtx2) + } + if valFromCtxHandler != "value2" { + t.Errorf("expected value 'value2' from context in handler (final value), got '%v'", valFromCtxHandler) + } +} + +func TestServer_handleRPC_NoInterceptors(t *testing.T) { + handlerCalled := false + mockRPCHandler := &mockHandler{ + fn: func(stream drpc.Stream, rpc string) error { + handlerCalled = true + return nil + }, + } + + // Create server without interceptors + srv := NewWithOptions(mockRPCHandler, Options{}) + if srv.opts.serverInt != nil { + t.Error("serverInt should be nil when no interceptors are provided") + } + + // check the behavior of the interceptor application logic + // when no interceptors are configured. + opts := &Options{} // Default options, serverInt should be nil + + var effectiveHandler drpc.Handler = mockRPCHandler + if opts.serverInt != nil { + // This block should not be hit if no interceptors are set + err := opts.serverInt(context.Background(), "TestRPC", &mockStream{}, mockRPCHandler) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } else { + // This is the expected path if no interceptors are set + err := effectiveHandler.HandleRPC(&mockStream{}, "TestRPC") + if err != nil { + t.Fatalf("handler returned an error: %v", err) + } + } + + if !handlerCalled { + t.Error("handler was not called when no interceptors are present") + } +} From 5441f897f33f35f744d071967dc012f7bec0e549 Mon Sep 17 00:00:00 2001 From: Shaik Zakir Hussain Date: Wed, 11 Jun 2025 11:13:16 +0530 Subject: [PATCH 2/2] fixup! drpcserver: add server interceptor support in drpc --- drpcserver/server_interceptor_test.go | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/drpcserver/server_interceptor_test.go b/drpcserver/server_interceptor_test.go index eca7268..9b629a7 100644 --- a/drpcserver/server_interceptor_test.go +++ b/drpcserver/server_interceptor_test.go @@ -174,9 +174,8 @@ const testCtxKey = contextKey("testKey") func TestServer_handleRPC_InterceptorContextPropagation(t *testing.T) { var ( - valFromCtx1 interface{} - valFromCtx2 interface{} - valFromCtxHandler interface{} + valFromCtx1 interface{} + valFromCtx2 interface{} ) interceptor1 := func(ctx context.Context, rpc string, stream drpc.Stream, handler drpc.Handler) error { @@ -193,8 +192,6 @@ func TestServer_handleRPC_InterceptorContextPropagation(t *testing.T) { mockRPCHandler := &mockHandler{ fn: func(stream drpc.Stream, rpc string) error { valFromCtx2 = stream.Context().Value(testCtxKey) // This should be "value2" - // Simulate handler using the context from the stream - valFromCtxHandler = stream.Context().Value(testCtxKey) return nil }, } @@ -218,9 +215,6 @@ func TestServer_handleRPC_InterceptorContextPropagation(t *testing.T) { if valFromCtx2 != "value2" { t.Errorf("expected value 'value2' from context in handler (set by interceptor2), got '%v'", valFromCtx2) } - if valFromCtxHandler != "value2" { - t.Errorf("expected value 'value2' from context in handler (final value), got '%v'", valFromCtxHandler) - } } func TestServer_handleRPC_NoInterceptors(t *testing.T) {