From eddab60297cbc06a449a17272d1fc68a8268e4f3 Mon Sep 17 00:00:00 2001 From: Alexei Boronine Date: Sun, 3 Nov 2024 17:19:33 +0100 Subject: [PATCH] Support half-open TCP connections --- README.md | 4 +- src/h2tunnel.test.ts | 277 ++++++++++++++++++++++--------------------- src/h2tunnel.ts | 95 ++++++++------- 3 files changed, 196 insertions(+), 180 deletions(-) diff --git a/README.md b/README.md index fc4d942..2313bd2 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ to the server, and the server proxies requests through this tunnel to your local ## How does h2tunnel work? -1. The client initiates a TLS connection to the server and starts listening for HTTP2 sessions on it +1. The client initiates a TLS connection to the server and starts listening for HTTP2 sessions 2. The server takes the newly created TLS socket and initiates an HTTP2 session through it 3. The server starts accepting TCP connections, converting them into HTTP2 streams, and fowarding them to the client 4. The client receives these HTTP2 streams and converts them back into TCP connections to feed them into the local server @@ -33,7 +33,7 @@ the server, and both are configured to reject anything else. This way, the pair Generate `h2tunnel.key` and `h2tunnel.crt` files using `openssl` command: ```bash -openssl req -x509 -newkey ec -pkeyopt ec_paramgen_curve:secp384r1 -days 3650 -nodes -keyout h2tunnel.key -out h2tunnel.crt -subj "/CN=example.com" +openssl req -x509 -newkey ec -pkeyopt ec_paramgen_curve:secp384r1 -days 3650 -nodes -keyout h2tunnel.key -out h2tunnel.crt -subj "/CN=localhost" ``` ### Forward localhost:8000 to http://example.com diff --git a/src/h2tunnel.test.ts b/src/h2tunnel.test.ts index d29496f..709f1f7 100644 --- a/src/h2tunnel.test.ts +++ b/src/h2tunnel.test.ts @@ -6,8 +6,6 @@ import { TunnelServer, } from "./h2tunnel.js"; import net from "node:net"; -import * as http2 from "node:http2"; -import { strictEqual } from "node:assert"; // localhost HTTP1 server "python3 -m http.server" const LOCAL_PORT = 14000; @@ -23,7 +21,7 @@ const TUNNEL_PORT = 14005; const MUX_PORT = 14006; // Reduce this to make tests faster -const TIME_MULTIPLIER = 0.1; +const TIME_MULTIPLIER = 0.05; const CLIENT_KEY = `-----BEGIN PRIVATE KEY----- MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDCDzcLnOqzvCrnUyd4P @@ -122,7 +120,7 @@ class NetworkEmulator { constructor( readonly originPort: number, readonly proxyPort: number, - readonly server = net.createServer(), + readonly server = net.createServer({ allowHalfOpen: true }), readonly logger = getLogger("network", 31), readonly abortController = new AbortController(), ) {} @@ -134,6 +132,7 @@ class NetworkEmulator { const outgoingSocket = net.createConnection({ host: "127.0.0.1", port: this.originPort, + allowHalfOpen: true, }); this.outgoingSocket = outgoingSocket; outgoingSocket.on("error", () => incomingSocket.resetAndDestroy()); @@ -161,7 +160,7 @@ class EchoServer { public proxyPort = port, readonly logger = getLogger("localhost", 35), ) { - const server = net.createServer(); + const server = net.createServer({ allowHalfOpen: true }); server.on("connection", (socket) => { logger({ echoServer: "connection" }); socket.on("error", (err) => { @@ -172,24 +171,27 @@ class EchoServer { echoServerData: data.toString(), socketWritableEnded: socket.writableEnded, }); - // Add to data received - const previousData = this.dataReceived.get(socket) ?? ""; - this.dataReceived.set(socket, previousData + data.toString("utf-8")); - + this.appendData(socket, data); if (!socket.writableEnded) { socket.write(data); } }); + // Make sure other end stays half-open long enough to receive the last byte + socket.on("end", async () => { + logger({ echoServer: "received FIN" }); + await sleep(50); + logger({ echoServer: "sending last byte" }); + socket.end("z"); + }); }); this.server = server; } - // reset(proxyPort: number) { - // this.proxyPort = proxyPort; - // this.dataReceived.clear(); - // this.i = 0; - // } + appendData(socket: net.Socket, data: Buffer): void { + const previousData = this.dataReceived.get(socket) ?? ""; + this.dataReceived.set(socket, previousData + data.toString("utf-8")); + } getSocketByPrefix(prefix: string): net.Socket { for (const [socket, data] of this.dataReceived) { @@ -215,7 +217,19 @@ class EchoServer { } createClientSocket(): net.Socket { - return net.createConnection(this.proxyPort); + const socket = net.createConnection({ + port: this.proxyPort, + allowHalfOpen: true, + }); + socket.on("data", (chunk) => { + this.appendData(socket, chunk); + }); + // Make sure other end stays half-open long enough to receive the last byte + socket.on("end", async () => { + await sleep(50); + socket.end("z"); + }); + return socket; } async expectEconn() { @@ -268,28 +282,19 @@ async function testConn( term: "FIN" | "RST", by: "client" | "server", delay: number = 0, - strict = true, ) { await sleep(delay); const conn = await server.createConn(t); - await t.test( - `ping pong ${numBytes} byte(s)`, - { plan: numBytes }, - async (t: TestContext) => { - for (let i = 0; i < numBytes; i++) { - await new Promise((resolve) => { - conn.originSocket.once("data", (pong) => { - t.assert.strictEqual(pong.toString(), "a"); - resolve(); - }); - // ping - const ping = "a"; - conn.clientSocket.write(ping); - }); - await sleep(50); - } - }, - ); + for (let i = 0; i < numBytes; i++) { + await new Promise((resolve) => { + conn.originSocket.once("data", (pong) => { + t.assert.strictEqual(pong.toString(), "a"); + resolve(); + }); + conn.clientSocket.write("a"); + }); + await sleep(50); + } const [socket1, socket2] = by === "client" @@ -297,119 +302,115 @@ async function testConn( : [conn.originSocket, conn.clientSocket]; if (term === "FIN") { - await t.test( - `clean termination by ${by} FIN`, - { plan: 12, timeout: 1000 }, - (t: TestContext) => - new Promise((resolve, reject) => { - let i = 0; - const done = () => i === 2 && resolve(); - t.assert.strictEqual(socket2.readyState, "open"); - t.assert.strictEqual(socket1.readyState, "open"); - socket2.on("end", () => { - // Server sent FIN and client received it - t.assert.strictEqual(socket2.readyState, "writeOnly"); - t.assert.strictEqual( - socket1.readyState, - strict ? "readOnly" : "closed", - ); - }); - socket2.on("close", (hasError) => { - t.assert.strictEqual(hasError, false); - t.assert.strictEqual(socket2.errored, null); - t.assert.strictEqual(socket2.readyState, "closed"); - i++; - done(); - }); - socket1.on("close", (hasError) => { - t.assert.strictEqual(hasError, false); - t.assert.strictEqual(socket1.errored, null); - t.assert.strictEqual(socket1.readyState, "closed"); - i++; - done(); - }); - socket1.end(); - // Server sent FIN, but client didn't receive it yet - t.assert.strictEqual(socket2.readyState, "open"); + const promise = Promise.all([ + new Promise((resolve) => { + socket2.on("end", () => { + // socket1 sent FIN and socket2 received it + t.assert.strictEqual(socket2.readyState, "writeOnly"); t.assert.strictEqual(socket1.readyState, "readOnly"); - }), - ); + resolve(); + }); + }), + new Promise((resolve) => { + socket2.on("close", (hasError) => { + t.assert.strictEqual(hasError, false); + t.assert.strictEqual(socket2.errored, null); + t.assert.strictEqual(socket2.readyState, "closed"); + resolve(); + }); + }), + new Promise((resolve) => { + socket1.on("close", (hasError) => { + t.assert.strictEqual(hasError, false); + t.assert.strictEqual(socket1.errored, null); + t.assert.strictEqual(socket1.readyState, "closed"); + resolve(); + }); + }), + ]); + t.assert.strictEqual(socket2.readyState, "open"); + t.assert.strictEqual(socket1.readyState, "open"); + socket1.end(); + // socket1 sent FIN, but socket2 didn't receive it yet + t.assert.strictEqual(socket2.readyState, "open"); + t.assert.strictEqual(socket1.readyState, "readOnly"); + await promise; + const socket1data = server.dataReceived.get(socket1); + const socket2data = server.dataReceived.get(socket2); + // Make sure last byte was successfully communicated in half-open state + t.assert.strictEqual(socket1data, socket2data + "z"); } else if (term == "RST") { - await t.test( - `clean reset by ${by} RST`, - { plan: 8, timeout: 1000 }, - (t: TestContext) => - new Promise((resolve) => { - let i = 0; - const done = () => i === 2 && resolve(); - socket2.on("error", (err) => { - t.assert.strictEqual(err["code"], "ECONNRESET"); - t.assert.strictEqual(socket2.readyState, "closed"); - t.assert.strictEqual(socket2.destroyed, true); - i++; - done(); - }); - socket1.on("close", (hasError) => { - // No error on our end because we initiated the RST - t.assert.strictEqual(hasError, false); - t.assert.strictEqual(socket1.readyState, "closed"); - t.assert.strictEqual(socket1.destroyed, true); - i++; - done(); - }); - socket1.resetAndDestroy(); + socket1.resetAndDestroy(); + t.assert.strictEqual(socket1.readyState, "closed"); + t.assert.strictEqual(socket2.readyState, "open"); + await Promise.all([ + new Promise((resolve) => { + socket2.on("error", (err) => { + t.assert.strictEqual(err["code"], "ECONNRESET"); + t.assert.strictEqual(socket2.readyState, "closed"); + t.assert.strictEqual(socket2.destroyed, true); + resolve(); + }); + }), + new Promise((resolve) => { + socket1.on("close", (hasError) => { + // No error on our end because we initiated the RST + t.assert.strictEqual(hasError, false); t.assert.strictEqual(socket1.readyState, "closed"); - t.assert.strictEqual(socket2.readyState, "open"); - }), - ); + t.assert.strictEqual(socket1.destroyed, true); + resolve(); + }); + }), + ]); } } -await test.only("basic connection and termination", async (t) => { - const net = new NetworkEmulator(LOCAL_PORT, PROXY_TEST_PORT); - const server = new TunnelServer(serverOptions); - const client = new TunnelClient(clientOptions); - server.start(); - client.start(); - await server.waitUntilListening(); - await client.waitUntilConnected(); - await server.waitUntilConnected(); - console.log(0, client.state); - await net.startAndWaitUntilReady(); - for (const term of ["FIN", "RST"] satisfies ("FIN" | "RST")[]) { - for (const by of ["client", "server"] satisfies ("client" | "server")[]) { - for (const proxyPort of [LOCAL_PORT, PROXY_TEST_PORT, PROXY_PORT]) { - await t.test( - `clean termination by ${by} ${term} on ${proxyPort}`, - async (t) => { - const echoServer = new EchoServer(LOCAL_PORT, proxyPort); - await echoServer.startAndWaitUntilReady(); - const strict = proxyPort !== PROXY_PORT; - // Test single - await testConn(t, echoServer, 1, term, by, 0, strict); - await testConn(t, echoServer, 4, term, by, 0, strict); - // Test double simultaneous - await Promise.all([ - testConn(t, echoServer, 3, term, by, 0, strict), - testConn(t, echoServer, 3, term, by, 0, strict), - ]); - // Test triple delayed - await Promise.all([ - testConn(t, echoServer, 4, term, by, 0, strict), - testConn(t, echoServer, 4, term, by, 10, strict), - testConn(t, echoServer, 4, term, by, 100, strict), - ]); - await echoServer.stopAndWaitUntilClosed(); - }, - ); +await test.only( + "basic connection and termination", + { timeout: 2000 }, + async (t) => { + const net = new NetworkEmulator(LOCAL_PORT, PROXY_TEST_PORT); + const server = new TunnelServer(serverOptions); + const client = new TunnelClient(clientOptions); + server.start(); + client.start(); + await server.waitUntilListening(); + await client.waitUntilConnected(); + await server.waitUntilConnected(); + await net.startAndWaitUntilReady(); + for (const term of ["FIN", "RST"] satisfies ("FIN" | "RST")[]) { + for (const by of ["client", "server"] satisfies ("client" | "server")[]) { + for (const proxyPort of [LOCAL_PORT, PROXY_TEST_PORT, PROXY_PORT]) { + // if (term !== "FIN" || by !== "client") { + // continue; + // } + console.log(`clean termination by ${by} ${term} on ${proxyPort}`); + const echoServer = new EchoServer(LOCAL_PORT, proxyPort); + await echoServer.startAndWaitUntilReady(); + // Test single + await testConn(t, echoServer, 1, term, by, 0); + await testConn(t, echoServer, 4, term, by, 0); + // Test double simultaneous + await Promise.all([ + testConn(t, echoServer, 3, term, by, 0), + testConn(t, echoServer, 3, term, by, 0), + ]); + // Test triple delayed + await Promise.all([ + testConn(t, echoServer, 4, term, by, 0), + testConn(t, echoServer, 4, term, by, 10), + testConn(t, echoServer, 4, term, by, 100), + ]); + await echoServer.stopAndWaitUntilClosed(); + } } } - } - await net.stopAndWaitUntilClosed(); - await client.stop(); - await server.stop(); -}); + await net.stopAndWaitUntilClosed(); + await client.stop(); + await server.stop(); + }, +); await test("happy-path", async (t) => { const echo = new EchoServer(LOCAL_PORT, PROXY_PORT); diff --git a/src/h2tunnel.ts b/src/h2tunnel.ts index 2f253f4..e1193c6 100644 --- a/src/h2tunnel.ts +++ b/src/h2tunnel.ts @@ -2,6 +2,7 @@ import net from "node:net"; import events from "node:events"; import * as http2 from "node:http2"; import * as tls from "node:tls"; +import stream from "node:stream"; export type TunnelState = "listening" | "connected" | "stopped" | "stopping"; @@ -67,8 +68,6 @@ export abstract class AbstractTunnel< this.muxSocket?.destroy(); this.tunnelSocket?.destroy(); this.h2session?.destroy(); - // Session does not get destroyed fast enough, we can have a situation where tunnel is closed but session remains - this.h2session = null; } setH2Session(session: S) { @@ -140,49 +139,64 @@ export abstract class AbstractTunnel< } addDemuxSocket(socket: net.Socket, stream: http2.Http2Stream): void { - this.log({ demuxSocket: "added", streamId: stream.id }); - socket.on("data", (chunk) => { - this.log({ streamDataWrite: chunk.length, streamId: stream.id }); - stream.write(chunk); - }); - stream.on("data", (chunk) => { - this.log({ streamDataRead: chunk.length, streamId: stream.id }); - socket.write(chunk); - }); - // Prevent error being logged, we are handling it during the "close" event - socket.on("error", () => {}); - socket.on("close", () => { + const log = (line: object) => { this.log({ - demuxSocket: "close", streamId: stream.id, + streamWritableEnd: stream.writableEnded, + socketWritableEnd: socket.writableEnded, + streamDestroyed: stream.destroyed, + socketDestroyed: socket.destroyed, streamError: stream.errored, socketError: socket.errored, + ...line, }); - if (!stream.destroyed) { - if (socket.errored) { - stream.close(http2.constants.NGHTTP2_INTERNAL_ERROR); - } else { - stream.destroy(); - } - } - }); - // Prevent error being logged, we are handling it during the "close" event - stream.on("error", () => {}); - stream.on("close", () => { - this.log({ - demuxStream: "close", - streamId: stream.id, - streamError: stream.errored, - socketError: socket.errored, + }; + log({ demux: "added" }); + + const setup = (duplex1: stream.Duplex, duplex2: stream.Duplex) => { + const isStream = duplex1 === stream; + const tag = isStream ? "demuxStream" : "demuxSocket"; + duplex1.on("data", (chunk) => { + log({ + [isStream ? "readBytes" : "writeBytes"]: chunk.length, + }); + duplex2.write(chunk); + }); + // Catch error but do not handle it, we will handle it later during the 'close' event + duplex1.on("error", () => { + log({ [tag]: "error" }); }); - if (!socket.destroyed) { - if (stream.errored) { - socket.resetAndDestroy(); - } else { - socket.destroy(); + let endTimeout: NodeJS.Timeout | null = null; + duplex1.on("end", () => { + log({ [tag]: "end" }); + // 'end' comes before 'error', so we need to wait make sure 'error' doesn't come it before processing 'end' + // https://github.com/nodejs/node/issues/39400 + endTimeout = setTimeout(() => { + if (!duplex2.writableEnded) { + log({ [tag]: "closing opposite" }); + duplex2.end(); + } + }, 30); + }); + + duplex1.on("close", () => { + log({ [tag]: "close" }); + if (duplex1.errored && !duplex2.destroyed) { + if (endTimeout) { + clearTimeout(endTimeout); + } + log({ [tag]: "destroying opposite" }); + if (isStream) { + socket.resetAndDestroy(); + } else { + stream.close(http2.constants.NGHTTP2_INTERNAL_ERROR); + } } - } - }); + }); + }; + + setup(socket, stream); + setup(stream, socket); } start() { @@ -244,7 +258,7 @@ export class TunnelServer extends AbstractTunnel< // This is necessary only if the client uses a self-signed certificate. ca: [options.cert], }), - readonly proxyServer = net.createServer(), + readonly proxyServer = net.createServer({ allowHalfOpen: true }), ) { super(options.logger, net.createServer(), options.muxListenPort); this.muxServer.on("connection", (socket: net.Socket) => { @@ -258,7 +272,7 @@ export class TunnelServer extends AbstractTunnel< } else { this.addDemuxSocket( socket, - this.h2session!.request({ + this.h2session.request({ [http2.constants.HTTP2_HEADER_METHOD]: "POST", }), ); @@ -368,6 +382,7 @@ export class TunnelClient extends AbstractTunnel< net.createConnection({ host: "127.0.0.1", port: this.options.localHttpPort, + allowHalfOpen: true, }), stream, );