Skip to content

Commit

Permalink
feat: add a unique ID (req_id) to all logs related to a request
Browse files Browse the repository at this point in the history
  • Loading branch information
ThinkChaos committed Mar 19, 2024
1 parent 641512e commit a6d2ca3
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 66 deletions.
88 changes: 60 additions & 28 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/0xERR0R/blocky/redis"
"github.com/0xERR0R/blocky/resolver"
"github.com/0xERR0R/blocky/util"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"

"github.com/go-chi/chi/v5"
Expand Down Expand Up @@ -417,15 +418,11 @@ func createQueryResolver(
}

func (s *Server) registerDNSHandlers(ctx context.Context) {
wrappedOnRequest := func(w dns.ResponseWriter, request *dns.Msg) {
ip, proto := resolveClientIPAndProtocol(w.RemoteAddr())

s.OnRequest(ctx, w, ip, proto, request)
}

for _, server := range s.dnsServers {
handler := server.Handler.(*dns.ServeMux)
handler.HandleFunc(".", wrappedOnRequest)
handler.HandleFunc(".", func(w dns.ResponseWriter, m *dns.Msg) {
s.OnRequest(ctx, w, m)
})
handler.HandleFunc("healthcheck.blocky", func(w dns.ResponseWriter, m *dns.Msg) {
s.OnHealthCheck(ctx, w, m)
})
Expand Down Expand Up @@ -570,43 +567,77 @@ func newRequest(
protocol model.RequestProtocol, request *dns.Msg,
) (context.Context, *model.Request) {
ctx, logger := log.CtxWithFields(ctx, logrus.Fields{
"req_id": uuid.New().String(),
"question": util.QuestionToString(request.Question),
"client_ip": clientIP,
})

return ctx, &model.Request{
logger.WithFields(logrus.Fields{
"client_request_id": request.Id,
"client_id": clientID,
"protocol": protocol,
}).Trace("new incoming request")

req := model.Request{
ClientIP: clientIP,
RequestClientID: clientID,
Protocol: protocol,
Req: request,
Log: logger,
RequestTS: time.Now(),
}

return ctx, &req
}

// OnRequest will be executed if a new DNS request is received
func (s *Server) OnRequest(
ctx context.Context, w dns.ResponseWriter,
clientIP net.IP, protocol model.RequestProtocol,
request *dns.Msg,
) {
logger().Debug("new request")
func newRequestFromDNS(ctx context.Context, rw dns.ResponseWriter, msg *dns.Msg) (context.Context, *model.Request) {
var (
clientIP net.IP
protocol model.RequestProtocol
)

var hostName string
if rw != nil {
clientIP, protocol = resolveClientIPAndProtocol(rw.RemoteAddr())
}

con, ok := w.(dns.ConnectionStater)
if ok && con.ConnectionState() != nil {
hostName = con.ConnectionState().ServerName
var clientID string
if con, ok := rw.(dns.ConnectionStater); ok && con.ConnectionState() != nil {
clientID = extractClientIDFromHost(con.ConnectionState().ServerName)
}

ctx, req := newRequest(ctx, clientIP, extractClientIDFromHost(hostName), protocol, request)
return newRequest(ctx, clientIP, clientID, protocol, msg)
}

func newRequestFromHTTP(ctx context.Context, req *http.Request, msg *dns.Msg) (context.Context, *model.Request) {
protocol := model.RequestProtocolTCP
clientIP := util.HTTPClientIP(req)

clientID := chi.URLParam(req, "clientID")
if clientID == "" {
clientID = extractClientIDFromHost(req.Host)
}

return newRequest(ctx, clientIP, clientID, protocol, msg)
}

// OnRequest will be executed if a new DNS request is received
func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, msg *dns.Msg) {
ctx, request := newRequestFromDNS(ctx, w, msg)

s.handleReq(ctx, request, w)
}

type msgWriter interface {
WriteMsg(msg *dns.Msg) error
}

response, err := s.resolve(ctx, req)
func (s *Server) handleReq(ctx context.Context, request *model.Request, w msgWriter) {
response, err := s.resolve(ctx, request)
if err != nil {
logger().Error("error on processing request:", err)
log.FromCtx(ctx).Error("error on processing request:", err)

m := new(dns.Msg)
m.SetRcode(request, dns.RcodeServerFailure)
m.SetRcode(request.Req, dns.RcodeServerFailure)
err := w.WriteMsg(m)
util.LogOnError(ctx, "can't write message: ", err)
} else {
Expand Down Expand Up @@ -634,7 +665,7 @@ func (s *Server) resolve(ctx context.Context, request *model.Request) (response
m := new(dns.Msg)
m.SetRcode(request.Req, dns.RcodeFormatError)

request.Log.Error("query has no questions")
log.FromCtx(ctx).Error("query has no questions")

response = &model.Response{Res: m, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}
default:
Expand Down Expand Up @@ -688,10 +719,11 @@ func (s *Server) OnHealthCheck(ctx context.Context, w dns.ResponseWriter, reques
}

func resolveClientIPAndProtocol(addr net.Addr) (ip net.IP, protocol model.RequestProtocol) {
if t, ok := addr.(*net.UDPAddr); ok {
return t.IP, model.RequestProtocolUDP
} else if t, ok := addr.(*net.TCPAddr); ok {
return t.IP, model.RequestProtocolTCP
switch a := addr.(type) {
case *net.UDPAddr:
return a.IP, model.RequestProtocolUDP
case *net.TCPAddr:
return a.IP, model.RequestProtocolTCP
}

return nil, model.RequestProtocolUDP
Expand Down
59 changes: 21 additions & 38 deletions server/server_endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request
s.processDohMessage(rawMsg, rw, req)
}

func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *http.Request) {
func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) {
msg := new(dns.Msg)
if err := msg.Unpack(rawMsg); err != nil {
logger().Error("can't deserialize message: ", err)
Expand All @@ -141,57 +141,40 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
return
}

rw.Header().Set("content-type", dnsContentType)
ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg)

writeErr := func(err error) {
log.Log().Error(err)

msg := new(dns.Msg)
msg.SetRcode(msg, dns.RcodeServerFailure)

buff, err := msg.Pack()
if err != nil {
return
}

// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
rw.WriteHeader(http.StatusOK)

_, _ = rw.Write(buff)
}

clientID := chi.URLParam(req, "clientID")
if clientID == "" {
clientID = extractClientIDFromHost(req.Host)
}
s.handleReq(ctx, dnsReq, httpMsgWriter{rw})
}

ctx, r := newRequest(req.Context(), util.HTTPClientIP(req), clientID, model.RequestProtocolTCP, msg)
type httpMsgWriter struct {
rw http.ResponseWriter
}

resResponse, err := s.resolve(ctx, r)
func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error {
b, err := msg.Pack()
if err != nil {
writeErr(fmt.Errorf("unable to process query: %w", err))

return
return err
}

b, err := resResponse.Res.Pack()
if err != nil {
writeErr(fmt.Errorf("can't serialize message: %w", err))
r.rw.Header().Set("content-type", dnsContentType)

return
}
// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
r.rw.WriteHeader(http.StatusOK)

_, err = rw.Write(b)
log.Log().Error(fmt.Errorf("can't write response: %w", err))
_, err = r.rw.Write(b)

return err
}

func (s *Server) Query(
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
) (*model.Response, error) {
dnsRequest := util.NewMsgWithQuestion(question, qType)
ctx, r := newRequest(ctx, clientIP, extractClientIDFromHost(serverHost), model.RequestProtocolTCP, dnsRequest)
msg := util.NewMsgWithQuestion(question, qType)
clientID := extractClientIDFromHost(serverHost)

ctx, req := newRequest(ctx, clientIP, clientID, model.RequestProtocolTCP, msg)

return s.resolve(ctx, r)
return s.resolve(ctx, req)
}

func createHTTPSRouter(cfg *config.Config) *chi.Mux {
Expand Down

0 comments on commit a6d2ca3

Please sign in to comment.