Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tcpreuse: fix Scope() for *tls.Conn #3181

Merged
merged 3 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions p2p/test/transport/gating_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,26 @@ import (

//go:generate go run go.uber.org/mock/mockgen -package transport_integration -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p/core/connmgr ConnectionGater

func stripCertHash(addr ma.Multiaddr) ma.Multiaddr {
// normalize removes the certhash and replaces /wss with /tls/ws
func normalize(addr ma.Multiaddr) ma.Multiaddr {
for {
if _, err := addr.ValueForProtocol(ma.P_CERTHASH); err != nil {
break
}
addr, _ = ma.SplitLast(addr)
}
return addr

// replace /wss with /tls/ws
components := []ma.Multiaddr{}
ma.ForEach(addr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_WSS {
components = append(components, ma.StringCast("/tls/ws"))
} else {
components = append(components, &c)
}
return true
})
return ma.Join(components...)
}

func addrPort(addr ma.Multiaddr) netip.AddrPort {
Expand Down Expand Up @@ -119,8 +131,7 @@ func TestInterceptSecuredOutgoing(t *testing.T) {
connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true),
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport and WebRTC addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.RemoteMultiaddr()))
}),
)
err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})
Expand Down Expand Up @@ -154,8 +165,7 @@ func TestInterceptUpgradedOutgoing(t *testing.T) {
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.RemoteMultiaddr())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(c.RemoteMultiaddr()))
require.Equal(t, h1.ID(), c.LocalPeer())
require.Equal(t, h2.ID(), c.RemotePeer())
}))
Expand Down Expand Up @@ -189,17 +199,15 @@ func TestInterceptAccept(t *testing.T) {
// In WebRTC, retransmissions of the STUN packet might cause us to create multiple connections,
// if the first connection attempt is rejected.
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr()))
}).AnyTimes()
} else if strings.Contains(tc.Name, "WebSocket-Shared") {
} else if strings.Contains(tc.Name, "WebSocket-Shared") || strings.Contains(tc.Name, "WebSocket-Secured-Shared") {
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr()))
})
} else {
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr(), "%s\n%s", h2.Addrs()[0], addrs.LocalMultiaddr())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr()))
})
}

Expand Down Expand Up @@ -236,8 +244,7 @@ func TestInterceptSecuredIncoming(t *testing.T) {
gomock.InOrder(
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr()))
}),
)
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
Expand Down Expand Up @@ -270,8 +277,7 @@ func TestInterceptUpgradedIncoming(t *testing.T) {
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.LocalMultiaddr())
require.Equal(t, normalize(h2.Addrs()[0]), normalize(c.LocalMultiaddr()))
require.Equal(t, h1.ID(), c.RemotePeer())
require.Equal(t, h2.ID(), c.LocalPeer())
}),
Expand Down
89 changes: 84 additions & 5 deletions p2p/test/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@ package transport_integration
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"fmt"
"io"
"math/big"
"net"
"runtime"
"strings"
Expand All @@ -15,6 +21,8 @@ import (
"testing"
"time"

libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"

"github.com/libp2p/go-libp2p"
"github.com/libp2p/go-libp2p/config"
"github.com/libp2p/go-libp2p/core/connmgr"
Expand All @@ -30,9 +38,9 @@ import (
"github.com/libp2p/go-libp2p/p2p/net/swarm"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
"github.com/libp2p/go-libp2p/p2p/security/noise"
tls "github.com/libp2p/go-libp2p/p2p/security/tls"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
"github.com/libp2p/go-libp2p/p2p/transport/websocket"
"go.uber.org/mock/gomock"

ma "github.com/multiformats/go-multiaddr"
Expand Down Expand Up @@ -68,6 +76,44 @@ func transformOpts(opts TransportTestCaseOpts) []config.Option {
return libp2pOpts
}

func selfSignedTLSConfig(t *testing.T) *tls.Config {
t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)

notBefore := time.Now()
notAfter := notBefore.Add(365 * 24 * time.Hour)

serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
require.NoError(t, err)

certTemplate := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Test"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}

derBytes, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &priv.PublicKey, priv)
require.NoError(t, err)

cert := tls.Certificate{
Certificate: [][]byte{derBytes},
PrivateKey: priv,
}

tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
return tlsConfig
}

var transportsToTest = []TransportTestCase{
{
Name: "TCP / Noise / Yamux",
Expand All @@ -89,7 +135,7 @@ var transportsToTest = []TransportTestCase{
Name: "TCP / TLS / Yamux",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New))
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
if opts.NoListen {
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
Expand All @@ -106,7 +152,7 @@ var transportsToTest = []TransportTestCase{
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New))
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
if opts.NoListen {
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
Expand All @@ -123,7 +169,7 @@ var transportsToTest = []TransportTestCase{
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New))
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
libp2pOpts = append(libp2pOpts, libp2p.Transport(tcp.NewTCPTransport, tcp.WithMetrics()))
if opts.NoListen {
Expand All @@ -140,7 +186,7 @@ var transportsToTest = []TransportTestCase{
Name: "TCP-WithMetrics / TLS / Yamux",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New))
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
libp2pOpts = append(libp2pOpts, libp2p.Transport(tcp.NewTCPTransport, tcp.WithMetrics()))
if opts.NoListen {
Expand Down Expand Up @@ -168,6 +214,23 @@ var transportsToTest = []TransportTestCase{
return h
},
},
{
Name: "WebSocket-Secured-Shared",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
if opts.NoListen {
config := tls.Config{InsecureSkipVerify: true}
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs, libp2p.Transport(websocket.New, websocket.WithTLSClientConfig(&config)))
} else {
config := selfSignedTLSConfig(t)
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/sni/localhost/tls/ws"), libp2p.Transport(websocket.New, websocket.WithTLSConfig(config)))
}
h, err := libp2p.New(libp2pOpts...)
require.NoError(t, err)
return h
},
},
{
Name: "WebSocket",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
Expand All @@ -182,6 +245,22 @@ var transportsToTest = []TransportTestCase{
return h
},
},
{
Name: "WebSocket-Secured",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
if opts.NoListen {
config := tls.Config{InsecureSkipVerify: true}
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs, libp2p.Transport(websocket.New, websocket.WithTLSClientConfig(&config)))
} else {
config := selfSignedTLSConfig(t)
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/sni/localhost/tls/ws"), libp2p.Transport(websocket.New, websocket.WithTLSConfig(config)))
}
h, err := libp2p.New(libp2pOpts...)
require.NoError(t, err)
return h
},
},
{
Name: "QUIC",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
Expand Down
5 changes: 5 additions & 0 deletions p2p/transport/tcpreuse/connwithscope.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ func (c connWithScope) Scope() network.ConnManagementScope {
return c.scope
}

func (c *connWithScope) Close() error {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is required for failures in the http tls handshake for wss. The http server closes the connection when handshake fails.

c.scope.Done()
return c.ManetTCPConnInterface.Close()
}

func manetConnWithScope(c manet.Conn, scope network.ConnManagementScope) (manet.Conn, error) {
if tcpconn, ok := c.(sampledconn.ManetTCPConnInterface); ok {
return &connWithScope{tcpconn, scope}, nil
Expand Down
1 change: 1 addition & 0 deletions p2p/transport/tcpreuse/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ func (m *multiplexedListener) run() error {
select {
case demux.buffer <- connWithScope:
case <-ctx.Done():
log.Debug("accept queue full; dropping connection from: %v", connWithScope.RemoteMultiaddr())
connWithScope.Close()
}
}()
Expand Down
8 changes: 8 additions & 0 deletions p2p/transport/websocket/conn.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package websocket

import (
"crypto/tls"
"errors"
"io"
"net"
Expand Down Expand Up @@ -142,6 +143,13 @@ func (c *Conn) Scope() network.ConnManagementScope {
}); ok {
return sc.Scope()
}
if nc, ok := nc.(*tls.Conn); ok {
if sc, ok := nc.NetConn().(interface {
Scope() network.ConnManagementScope
}); ok {
return sc.Scope()
}
}
return nil
}

Expand Down
Loading