Skip to content

Commit 9f435ef

Browse files
committed
Streamed replies are now handled by Request
1 parent c59a3ad commit 9f435ef

File tree

4 files changed

+78
-156
lines changed

4 files changed

+78
-156
lines changed

examples/alloptions/alloptions.nrpc.go

+21-102
Original file line numberDiff line numberDiff line change
@@ -51,70 +51,6 @@ func (h *SvcCustomSubjectHandler) MtNoRequestPublish(pkginstance string, msg Sim
5151
return h.nc.Publish(subject, rawMsg)
5252
}
5353

54-
func (h *SvcCustomSubjectHandler) MtStreamedReplyHandler(
55-
ctx context.Context, request *nrpc.Request, req StringArg) {
56-
ctx, cancel := context.WithCancel(ctx)
57-
58-
keepStreamAlive := nrpc.NewKeepStreamAlive(
59-
request.Conn, request.ReplySubject, request.Encoding, cancel,
60-
)
61-
62-
var msgCount uint32
63-
64-
_, nrpcErr := nrpc.CaptureErrors(func() (proto.Message, error) {
65-
err := h.server.MtStreamedReply(ctx, req, func(rep SimpleStringReply){
66-
if err := request.SendReply(&rep, nil); err != nil {
67-
log.Printf("nrpc: error publishing response")
68-
cancel()
69-
return
70-
}
71-
msgCount++
72-
})
73-
return nil, err
74-
})
75-
keepStreamAlive.Stop()
76-
77-
if nrpcErr != nil {
78-
request.SendReply(nil, nrpcErr)
79-
} else {
80-
request.SendReply(
81-
nil, &nrpc.Error{Type: nrpc.Error_EOS, MsgCount: msgCount},
82-
)
83-
}
84-
}
85-
86-
func (h *SvcCustomSubjectHandler) MtVoidReqStreamedReplyHandler(
87-
ctx context.Context, request *nrpc.Request) {
88-
ctx, cancel := context.WithCancel(ctx)
89-
90-
keepStreamAlive := nrpc.NewKeepStreamAlive(
91-
request.Conn, request.ReplySubject, request.Encoding, cancel,
92-
)
93-
94-
var msgCount uint32
95-
96-
_, nrpcErr := nrpc.CaptureErrors(func() (proto.Message, error) {
97-
err := h.server.MtVoidReqStreamedReply(ctx, func(rep SimpleStringReply){
98-
if err := request.SendReply(&rep, nil); err != nil {
99-
log.Printf("nrpc: error publishing response")
100-
cancel()
101-
return
102-
}
103-
msgCount++
104-
})
105-
return nil, err
106-
})
107-
keepStreamAlive.Stop()
108-
109-
if nrpcErr != nil {
110-
request.SendReply(nil, nrpcErr)
111-
} else {
112-
request.SendReply(
113-
nil, &nrpc.Error{Type: nrpc.Error_EOS, MsgCount: msgCount},
114-
)
115-
}
116-
}
117-
11854
func (h *SvcCustomSubjectHandler) Handler(msg *nats.Msg) {
11955
request := nrpc.NewRequest(h.ctx, h.nc, msg.Subject, msg.Reply)
12056
// extract method name & encoding from subject
@@ -194,8 +130,13 @@ func (h *SvcCustomSubjectHandler) Handler(msg *nats.Msg) {
194130
Message: "bad request received: " + err.Error(),
195131
}
196132
} else {
197-
h.MtStreamedReplyHandler(h.ctx, request, req)
198-
return
133+
request.SetupStreamedReply()
134+
request.Handler = func(ctx context.Context)(proto.Message, error){
135+
err := h.server.MtStreamedReply(ctx, req, func(rep SimpleStringReply){
136+
request.SendStreamReply(&rep)
137+
})
138+
return nil, err
139+
}
199140
}
200141
case "mtvoidreqstreamedreply":
201142
_, request.Encoding, err = nrpc.ParseSubjectTail(0, request.SubjectTail)
@@ -211,8 +152,13 @@ func (h *SvcCustomSubjectHandler) Handler(msg *nats.Msg) {
211152
Message: "bad request received: " + err.Error(),
212153
}
213154
} else {
214-
h.MtVoidReqStreamedReplyHandler(h.ctx, request)
215-
return
155+
request.SetupStreamedReply()
156+
request.Handler = func(ctx context.Context)(proto.Message, error){
157+
err := h.server.MtVoidReqStreamedReply(ctx, func(rep SimpleStringReply){
158+
request.SendStreamReply(&rep)
159+
})
160+
return nil, err
161+
}
216162
}
217163
default:
218164
log.Printf("SvcCustomSubjectHandler: unknown name %q", name)
@@ -428,38 +374,6 @@ func (h *SvcSubjectParamsHandler) Subject() string {
428374
return "root.*.svcsubjectparams.*.>"
429375
}
430376

431-
func (h *SvcSubjectParamsHandler) MtStreamedReplyWithSubjectParamsHandler(
432-
ctx context.Context, request *nrpc.Request, mtParams []string) {
433-
ctx, cancel := context.WithCancel(ctx)
434-
435-
keepStreamAlive := nrpc.NewKeepStreamAlive(
436-
request.Conn, request.ReplySubject, request.Encoding, cancel,
437-
)
438-
439-
var msgCount uint32
440-
441-
_, nrpcErr := nrpc.CaptureErrors(func() (proto.Message, error) {
442-
err := h.server.MtStreamedReplyWithSubjectParams(ctx, mtParams[0], mtParams[1], func(rep SimpleStringReply){
443-
if err := request.SendReply(&rep, nil); err != nil {
444-
log.Printf("nrpc: error publishing response")
445-
cancel()
446-
return
447-
}
448-
msgCount++
449-
})
450-
return nil, err
451-
})
452-
keepStreamAlive.Stop()
453-
454-
if nrpcErr != nil {
455-
request.SendReply(nil, nrpcErr)
456-
} else {
457-
request.SendReply(
458-
nil, &nrpc.Error{Type: nrpc.Error_EOS, MsgCount: msgCount},
459-
)
460-
}
461-
}
462-
463377
func (h *SvcSubjectParamsHandler) MtNoRequestWParamsPublish(pkginstance string, svcclientid string, mtmp1 string, msg SimpleStringReply) error {
464378
rawMsg, err := nrpc.Marshal("protobuf", &msg)
465379
if err != nil {
@@ -526,8 +440,13 @@ func (h *SvcSubjectParamsHandler) Handler(msg *nats.Msg) {
526440
Message: "bad request received: " + err.Error(),
527441
}
528442
} else {
529-
h.MtStreamedReplyWithSubjectParamsHandler(h.ctx, request, mtParams)
530-
return
443+
request.SetupStreamedReply()
444+
request.Handler = func(ctx context.Context)(proto.Message, error){
445+
err := h.server.MtStreamedReplyWithSubjectParams(ctx, mtParams[0], mtParams[1], func(rep SimpleStringReply){
446+
request.SendStreamReply(&rep)
447+
})
448+
return nil, err
449+
}
531450
}
532451
case "mtnoreply":
533452
request.NoReply = true

examples/alloptions/alloptions_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ func TestAll(t *testing.T) {
158158
context.Background(),
159159
StringArg{Arg1: "arg"},
160160
func(ctx context.Context, rep SimpleStringReply) {
161+
fmt.Println("received", rep)
161162
resList = append(resList, rep.GetReply())
162163
})
163164
if err != nil {

nrpc.go

+46-1
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,11 @@ type Request struct {
260260
Context context.Context
261261
Conn NatsConn
262262

263+
KeepStreamAlive *KeepStreamAlive
264+
StreamContext context.Context
265+
StreamCancel func()
266+
StreamMsgCount uint32
267+
263268
Subject string
264269
MethodName string
265270
SubjectTail []string
@@ -286,7 +291,11 @@ func (r Request) Elapsed() time.Duration {
286291
// that should be returned to the caller
287292
func (r Request) Run() (msg proto.Message, replyError *Error) {
288293
r.StartedAt = time.Now()
289-
ctx := context.WithValue(r.Context, RequestContextKey, &r)
294+
ctx := r.Context
295+
if r.StreamedReply() {
296+
ctx = r.StreamContext
297+
}
298+
ctx = context.WithValue(ctx, RequestContextKey, &r)
290299
msg, replyError = CaptureErrors(
291300
func() (proto.Message, error) {
292301
return r.Handler(ctx)
@@ -326,8 +335,44 @@ func (r *Request) SetServiceParam(key, value string) {
326335
r.ServiceParams[key] = value
327336
}
328337

338+
// SetupStreamedReply initializes the reply stream
339+
func (r *Request) SetupStreamedReply() {
340+
r.StreamContext, r.StreamCancel = context.WithCancel(r.Context)
341+
r.KeepStreamAlive = NewKeepStreamAlive(
342+
r.Conn, r.ReplySubject, r.Encoding, r.StreamCancel)
343+
}
344+
345+
// StreamedReply returns true if the request reply is streamed
346+
func (r Request) StreamedReply() bool {
347+
return r.KeepStreamAlive != nil
348+
}
349+
350+
// SendStreamReply send a reply a part of a stream
351+
func (r *Request) SendStreamReply(msg proto.Message) {
352+
log.Printf("nrpc: SendStreamReply")
353+
if err := r.sendReply(msg, nil); err != nil {
354+
log.Printf("nrpc: error publishing response")
355+
r.StreamCancel()
356+
return
357+
}
358+
r.StreamMsgCount++
359+
}
360+
329361
// SendReply sends a reply to the caller
330362
func (r *Request) SendReply(resp proto.Message, withError *Error) error {
363+
if r.StreamedReply() {
364+
r.KeepStreamAlive.Stop()
365+
if withError == nil {
366+
return r.sendReply(
367+
nil, &Error{Type: Error_EOS, MsgCount: r.StreamMsgCount},
368+
)
369+
}
370+
}
371+
return r.sendReply(resp, withError)
372+
}
373+
374+
// sendReply sends a reply to the caller
375+
func (r *Request) sendReply(resp proto.Message, withError *Error) error {
331376
return Publish(resp, withError, r.Conn, r.ReplySubject, r.Encoding)
332377
}
333378

protoc-gen-nrpc/tmpl.go

+10-53
Original file line numberDiff line numberDiff line change
@@ -155,54 +155,6 @@ func (h *{{$serviceName}}Handler) {{.GetName}}Publish(
155155
return h.nc.Publish(subject, rawMsg)
156156
}
157157
{{- end}}
158-
{{- if HasStreamedReply .}}
159-
160-
func (h *{{$serviceName}}Handler) {{.GetName}}Handler(
161-
ctx context.Context, request *nrpc.Request
162-
{{- if GetMethodSubjectParams . -}}
163-
, mtParams []string
164-
{{- end -}}
165-
{{- if ne .GetInputType ".nrpc.Void" -}}
166-
, req {{GoType .GetInputType}}
167-
{{- end -}}
168-
) {
169-
ctx, cancel := context.WithCancel(ctx)
170-
171-
keepStreamAlive := nrpc.NewKeepStreamAlive(
172-
request.Conn, request.ReplySubject, request.Encoding, cancel,
173-
)
174-
175-
var msgCount uint32
176-
177-
_, nrpcErr := nrpc.CaptureErrors(func() (proto.Message, error) {
178-
err := h.server.{{.GetName}}(ctx
179-
{{- range $i, $p := GetMethodSubjectParams . -}}
180-
, mtParams[{{ $i }}]
181-
{{- end -}}
182-
{{- if ne .GetInputType ".nrpc.Void" -}}
183-
, req
184-
{{- end -}}
185-
, func(rep {{GoType .GetOutputType}}){
186-
if err := request.SendReply(&rep, nil); err != nil {
187-
log.Printf("nrpc: error publishing response")
188-
cancel()
189-
return
190-
}
191-
msgCount++
192-
})
193-
return nil, err
194-
})
195-
keepStreamAlive.Stop()
196-
197-
if nrpcErr != nil {
198-
request.SendReply(nil, nrpcErr)
199-
} else {
200-
request.SendReply(
201-
nil, &nrpc.Error{Type: nrpc.Error_EOS, MsgCount: msgCount},
202-
)
203-
}
204-
}
205-
{{- end}}
206158
{{- end}}
207159
208160
{{- if ServiceNeedsHandler .}}
@@ -262,15 +214,20 @@ func (h *{{.GetName}}Handler) Handler(msg *nats.Msg) {
262214
{{- end}}
263215
} else {
264216
{{- if HasStreamedReply .}}
265-
h.{{.GetName}}Handler(h.ctx, request
266-
{{- if GetMethodSubjectParams . -}}
267-
, mtParams
217+
request.SetupStreamedReply()
218+
request.Handler = func(ctx context.Context)(proto.Message, error){
219+
err := h.server.{{.GetName}}(ctx
220+
{{- range $i, $p := GetMethodSubjectParams . -}}
221+
, mtParams[{{ $i }}]
268222
{{- end -}}
269223
{{- if ne .GetInputType ".nrpc.Void" -}}
270224
, req
271225
{{- end -}}
272-
)
273-
return
226+
, func(rep {{GoType .GetOutputType}}){
227+
request.SendStreamReply(&rep)
228+
})
229+
return nil, err
230+
}
274231
{{- else }}
275232
request.Handler = func(ctx context.Context)(proto.Message, error){
276233
{{- if eq .GetOutputType ".nrpc.NoReply" -}}

0 commit comments

Comments
 (0)