diff --git a/p2p/test/transport/gating_test.go b/p2p/test/transport/gating_test.go index 99ce67b521..a26378357a 100644 --- a/p2p/test/transport/gating_test.go +++ b/p2p/test/transport/gating_test.go @@ -22,14 +22,26 @@ import ( //go:generate go run go.uber.org/mock/mockgen -package transport_integration -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p/core/connmgr ConnectionGater -func stripCertHash(addr ma.Multiaddr) ma.Multiaddr { +// normalize removes the certhash and replaces /wss with /tls/ws +func normalize(addr ma.Multiaddr) ma.Multiaddr { for { if _, err := addr.ValueForProtocol(ma.P_CERTHASH); err != nil { break } addr, _ = ma.SplitLast(addr) } - return addr + + // replace /wss with /tls/ws + components := []ma.Multiaddr{} + ma.ForEach(addr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_WSS { + components = append(components, ma.StringCast("/tls/ws")) + } else { + components = append(components, &c) + } + return true + }) + return ma.Join(components...) } func addrPort(addr ma.Multiaddr) netip.AddrPort { @@ -119,8 +131,7 @@ func TestInterceptSecuredOutgoing(t *testing.T) { connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true), connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true), connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { - // remove the certhash component from WebTransport and WebRTC addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String()) + require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.RemoteMultiaddr())) }), ) err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}) @@ -154,8 +165,7 @@ func TestInterceptUpgradedOutgoing(t *testing.T) { connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true), connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true), connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) { - // remove the certhash component from WebTransport addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]), c.RemoteMultiaddr()) + require.Equal(t, normalize(h2.Addrs()[0]), normalize(c.RemoteMultiaddr())) require.Equal(t, h1.ID(), c.LocalPeer()) require.Equal(t, h2.ID(), c.RemotePeer()) })) @@ -189,17 +199,15 @@ func TestInterceptAccept(t *testing.T) { // In WebRTC, retransmissions of the STUN packet might cause us to create multiple connections, // if the first connection attempt is rejected. connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { - // remove the certhash component from WebTransport addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) + require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr())) }).AnyTimes() - } else if strings.Contains(tc.Name, "WebSocket-Shared") { + } else if strings.Contains(tc.Name, "WebSocket-Shared") || strings.Contains(tc.Name, "WebSocket-Secured-Shared") { connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr())) }) } else { connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { - // remove the certhash component from WebTransport addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr(), "%s\n%s", h2.Addrs()[0], addrs.LocalMultiaddr()) + require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr())) }) } @@ -236,8 +244,7 @@ func TestInterceptSecuredIncoming(t *testing.T) { gomock.InOrder( connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true), connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { - // remove the certhash component from WebTransport addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) + require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr())) }), ) h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) @@ -270,8 +277,7 @@ func TestInterceptUpgradedIncoming(t *testing.T) { connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true), connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true), connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) { - // remove the certhash component from WebTransport addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]), c.LocalMultiaddr()) + require.Equal(t, normalize(h2.Addrs()[0]), normalize(c.LocalMultiaddr())) require.Equal(t, h1.ID(), c.RemotePeer()) require.Equal(t, h2.ID(), c.LocalPeer()) }), diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 9936463e23..c8445a3997 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -3,10 +3,16 @@ package transport_integration import ( "bytes" "context" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "errors" "fmt" "io" + "math/big" "net" "runtime" "strings" @@ -15,6 +21,8 @@ import ( "testing" "time" + libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" + "github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p/config" "github.com/libp2p/go-libp2p/core/connmgr" @@ -30,9 +38,9 @@ import ( "github.com/libp2p/go-libp2p/p2p/net/swarm" "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/security/noise" - tls "github.com/libp2p/go-libp2p/p2p/security/tls" "github.com/libp2p/go-libp2p/p2p/transport/tcp" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" + "github.com/libp2p/go-libp2p/p2p/transport/websocket" "go.uber.org/mock/gomock" ma "github.com/multiformats/go-multiaddr" @@ -68,6 +76,44 @@ func transformOpts(opts TransportTestCaseOpts) []config.Option { return libp2pOpts } +func selfSignedTLSConfig(t *testing.T) *tls.Config { + t.Helper() + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + notBefore := time.Now() + notAfter := notBefore.Add(365 * 24 * time.Hour) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + require.NoError(t, err) + + certTemplate := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Test"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &priv.PublicKey, priv) + require.NoError(t, err) + + cert := tls.Certificate{ + Certificate: [][]byte{derBytes}, + PrivateKey: priv, + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + return tlsConfig +} + var transportsToTest = []TransportTestCase{ { Name: "TCP / Noise / Yamux", @@ -89,7 +135,7 @@ var transportsToTest = []TransportTestCase{ Name: "TCP / TLS / Yamux", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { libp2pOpts := transformOpts(opts) - libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New)) + libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New)) libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport)) if opts.NoListen { libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) @@ -106,7 +152,7 @@ var transportsToTest = []TransportTestCase{ HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { libp2pOpts := transformOpts(opts) libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener()) - libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New)) + libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New)) libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport)) if opts.NoListen { libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) @@ -123,7 +169,7 @@ var transportsToTest = []TransportTestCase{ HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { libp2pOpts := transformOpts(opts) libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener()) - libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New)) + libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New)) libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport)) libp2pOpts = append(libp2pOpts, libp2p.Transport(tcp.NewTCPTransport, tcp.WithMetrics())) if opts.NoListen { @@ -140,7 +186,7 @@ var transportsToTest = []TransportTestCase{ Name: "TCP-WithMetrics / TLS / Yamux", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { libp2pOpts := transformOpts(opts) - libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New)) + libp2pOpts = append(libp2pOpts, libp2p.Security(libp2ptls.ID, libp2ptls.New)) libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport)) libp2pOpts = append(libp2pOpts, libp2p.Transport(tcp.NewTCPTransport, tcp.WithMetrics())) if opts.NoListen { @@ -168,6 +214,23 @@ var transportsToTest = []TransportTestCase{ return h }, }, + { + Name: "WebSocket-Secured-Shared", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener()) + if opts.NoListen { + config := tls.Config{InsecureSkipVerify: true} + libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs, libp2p.Transport(websocket.New, websocket.WithTLSClientConfig(&config))) + } else { + config := selfSignedTLSConfig(t) + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/sni/localhost/tls/ws"), libp2p.Transport(websocket.New, websocket.WithTLSConfig(config))) + } + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return h + }, + }, { Name: "WebSocket", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { @@ -182,6 +245,22 @@ var transportsToTest = []TransportTestCase{ return h }, }, + { + Name: "WebSocket-Secured", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + if opts.NoListen { + config := tls.Config{InsecureSkipVerify: true} + libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs, libp2p.Transport(websocket.New, websocket.WithTLSClientConfig(&config))) + } else { + config := selfSignedTLSConfig(t) + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/sni/localhost/tls/ws"), libp2p.Transport(websocket.New, websocket.WithTLSConfig(config))) + } + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return h + }, + }, { Name: "QUIC", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { diff --git a/p2p/transport/tcpreuse/connwithscope.go b/p2p/transport/tcpreuse/connwithscope.go index ca66f20325..bddd3c0f3b 100644 --- a/p2p/transport/tcpreuse/connwithscope.go +++ b/p2p/transport/tcpreuse/connwithscope.go @@ -17,6 +17,11 @@ func (c connWithScope) Scope() network.ConnManagementScope { return c.scope } +func (c *connWithScope) Close() error { + c.scope.Done() + return c.ManetTCPConnInterface.Close() +} + func manetConnWithScope(c manet.Conn, scope network.ConnManagementScope) (manet.Conn, error) { if tcpconn, ok := c.(sampledconn.ManetTCPConnInterface); ok { return &connWithScope{tcpconn, scope}, nil diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go index d94186e7ec..d0affb9fdf 100644 --- a/p2p/transport/tcpreuse/listener.go +++ b/p2p/transport/tcpreuse/listener.go @@ -260,6 +260,7 @@ func (m *multiplexedListener) run() error { select { case demux.buffer <- connWithScope: case <-ctx.Done(): + log.Debug("accept timeout; dropping connection from: %v", connWithScope.RemoteMultiaddr()) connWithScope.Close() } }() diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index ce51611703..1c2ecd03df 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -1,6 +1,7 @@ package websocket import ( + "crypto/tls" "errors" "io" "net" @@ -142,6 +143,13 @@ func (c *Conn) Scope() network.ConnManagementScope { }); ok { return sc.Scope() } + if nc, ok := nc.(*tls.Conn); ok { + if sc, ok := nc.NetConn().(interface { + Scope() network.ConnManagementScope + }); ok { + return sc.Scope() + } + } return nil } diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index dd399aa079..93131a2e07 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -137,9 +137,9 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } select { - case l.incoming <- NewConn(c, l.isWss): + case l.incoming <- nc: case <-l.closed: - c.Close() + nc.Close() } // The connection has been hijacked, it's safe to return. }