From c29d41c029a27e9d93d183abf1e805196bc1f6f2 Mon Sep 17 00:00:00 2001 From: yanglihz <798862514@qq.com> Date: Tue, 28 Jun 2022 10:23:33 +0800 Subject: [PATCH] optimizing NIO processing with reactor thread model --- .../org/java_websocket/AbstractWebSocket.java | 2 +- .../org/java_websocket/WebSocketImpl.java | 122 +++++++++++++++++- .../org/java_websocket/reactor/Acceptor.java | 111 ++++++++++++++++ .../java_websocket/reactor/TCPReactor.java | 79 ++++++++++++ .../java_websocket/reactor/TCPSubReactor.java | 63 +++++++++ .../server/WebSocketServer.java | 117 ++++++----------- 6 files changed, 414 insertions(+), 80 deletions(-) create mode 100644 src/main/java/org/java_websocket/reactor/Acceptor.java create mode 100644 src/main/java/org/java_websocket/reactor/TCPReactor.java create mode 100644 src/main/java/org/java_websocket/reactor/TCPSubReactor.java diff --git a/src/main/java/org/java_websocket/AbstractWebSocket.java b/src/main/java/org/java_websocket/AbstractWebSocket.java index c3e77a089..fb6422f34 100644 --- a/src/main/java/org/java_websocket/AbstractWebSocket.java +++ b/src/main/java/org/java_websocket/AbstractWebSocket.java @@ -162,7 +162,7 @@ protected void stopConnectionLostTimer() { * * @since 1.3.4 */ - protected void startConnectionLostTimer() { + public void startConnectionLostTimer() { synchronized (syncConnectionLost) { if (this.connectionLostTimeout <= 0) { log.trace("Connection lost timer deactivated"); diff --git a/src/main/java/org/java_websocket/WebSocketImpl.java b/src/main/java/org/java_websocket/WebSocketImpl.java index aad172127..d6b2ad37a 100644 --- a/src/main/java/org/java_websocket/WebSocketImpl.java +++ b/src/main/java/org/java_websocket/WebSocketImpl.java @@ -29,6 +29,7 @@ import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.ByteChannel; +import java.nio.channels.SelectableChannel; import java.nio.channels.SelectionKey; import java.util.ArrayList; import java.util.Collection; @@ -59,6 +60,7 @@ import org.java_websocket.handshake.ServerHandshakeBuilder; import org.java_websocket.interfaces.ISSLChannel; import org.java_websocket.protocols.IProtocol; +import org.java_websocket.server.WebSocketServer; import org.java_websocket.server.WebSocketServer.WebSocketWorker; import org.java_websocket.util.Charsetfunctions; import org.slf4j.Logger; @@ -69,7 +71,7 @@ * "handshake" phase, then allows for easy sending of text frames, and receiving frames through an * event-based model. */ -public class WebSocketImpl implements WebSocket { +public class WebSocketImpl implements WebSocket, Runnable { /** * The default port of WebSockets, as defined in the spec. If the nullary constructor is used, @@ -914,6 +916,120 @@ public WebSocketWorker getWorkerThread() { public void setWorkerThread(WebSocketWorker workerThread) { this.workerThread = workerThread; } - - + + @Override + public void run() { + if( key.isValid() ) { + WebSocketImpl conn = (WebSocketImpl) key.attachment(); + if (key.isReadable()) { + ByteBuffer buf = null; + try { + synchronized (this.wsl) { + buf = ((WebSocketServer) wsl).takeBuffer(); + } + } catch (InterruptedException e) { + log.error(e.getMessage(), e); + } + if(conn.getChannel() == null){ + key.cancel(); + handleIOException( key, conn, new IOException() ); + } + try { + if( SocketChannelIOHelper.read(buf, conn, conn.getChannel())) { + if(buf.hasRemaining()) { + conn.inQueue.put( buf ); + synchronized (this.wsl) { + ((WebSocketServer) wsl).queue( conn ); + } + if(conn.getChannel() instanceof WrappedByteChannel && ((WrappedByteChannel) conn.getChannel()).isNeedRead()) { + synchronized (this.wsl) { + ((WebSocketServer) wsl).getIqueue().add( conn ); + } + } + } else { + synchronized (this.wsl) { + ((WebSocketServer) wsl).pushBuffer(buf); + } + } + } else { + synchronized (this.wsl) { + ((WebSocketServer) wsl).pushBuffer( buf ); + } + } + } catch ( IOException e ) { + try { + synchronized (this.wsl) { + ((WebSocketServer) wsl).pushBuffer( buf ); + } + } catch (InterruptedException e1) { + log.error(e1.getMessage(), e1); + } + } catch (InterruptedException e) { + log.error(e.getMessage(), e); + } + } else if(key.isWritable()) { + try { + if(SocketChannelIOHelper.batch(conn, conn.getChannel()) && key.isValid()) { + key.interestOps(SelectionKey.OP_READ); + } + } catch (IOException e) { + log.error(e.getMessage(), e); + } + } + try { + synchronized (this.wsl) { + doAdditionalRead(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (IOException e) { + if( key != null ) + key.cancel(); + handleIOException( key, conn, e ); + } + } + } + + private void handleIOException(SelectionKey key, WebSocket conn, IOException ex) { + // onWebsocketError( conn, ex );// conn may be null here + if (key != null) { + key.cancel(); + } + if (conn != null) { + conn.closeConnection(CloseFrame.ABNORMAL_CLOSE, ex.getMessage()); + } else if (key != null) { + SelectableChannel channel = key.channel(); + if (channel != null && channel + .isOpen()) { // this could be the case if the IOException ex is a SSLException + try { + channel.close(); + } catch (IOException e) { + // there is nothing that must be done here + } + log.trace("Connection closed because of exception", ex); + } + } + } + + private void doAdditionalRead() throws InterruptedException, IOException { + WebSocketImpl conn; + while ( !((WebSocketServer) wsl).getIqueue().isEmpty() ) { + conn = ((WebSocketServer) wsl).getIqueue().remove( 0 ); + WrappedByteChannel c = ( (WrappedByteChannel) conn.getChannel() ); + ByteBuffer buf = ((WebSocketServer) wsl).takeBuffer(); + try { + if( SocketChannelIOHelper.readMore( buf, conn, c ) ) + ((WebSocketServer) wsl).getIqueue().add( conn ); + if( buf.hasRemaining() ) { + conn.inQueue.put( buf ); + ((WebSocketServer) wsl).queue( conn ); + } else { + ((WebSocketServer) wsl).pushBuffer( buf ); + } + } catch ( IOException e ) { + ((WebSocketServer) wsl).pushBuffer( buf ); + throw e; + } + } + } } diff --git a/src/main/java/org/java_websocket/reactor/Acceptor.java b/src/main/java/org/java_websocket/reactor/Acceptor.java new file mode 100644 index 000000000..712535cc2 --- /dev/null +++ b/src/main/java/org/java_websocket/reactor/Acceptor.java @@ -0,0 +1,111 @@ +package org.java_websocket.reactor; + +import java.io.IOException; +import java.net.Socket; +import java.nio.ByteBuffer; +import java.nio.channels.SelectableChannel; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; + +import org.java_websocket.WebSocket; +import org.java_websocket.WebSocketImpl; +import org.java_websocket.framing.CloseFrame; +import org.java_websocket.server.WebSocketServer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * acceptor process new client connections and dispatch requests to the processor chain + */ +public class Acceptor implements Runnable { + private static final Logger log = LoggerFactory.getLogger(Acceptor.class); + private final ServerSocketChannel ssc; // socket channel monitored by mainReactor + private final WebSocketServer wss; + private final int cores = Runtime.getRuntime().availableProcessors(); // get the number of CPU cores + private final Selector[] selectors = new Selector[cores]; // create several core selectors for subReactor + private int selIdx = 0; // currently available subreactor indexes + private TCPSubReactor[] r = new TCPSubReactor[cores]; // subReactor thread + private Thread[] t = new Thread[cores]; // subReactor thread + + public Acceptor(ServerSocketChannel s, WebSocketServer server) throws IOException { + this.ssc = s; + this.wss = server; + // Create multiple selectors and multiple subReactor threads + for (int i = 0; i < cores; i++) { + selectors[i] = Selector.open(); + r[i] = new TCPSubReactor(selectors[i], s, i); + t[i] = new Thread(r[i]); + t[i].start(); + } + } + + @Override + public void run() { + try { + SocketChannel sc = ssc.accept(); // receive client connection request + if (sc != null) { + log.trace(sc.socket().getRemoteSocketAddress().toString() + " is connected."); + log.trace("selIdx is {}", selIdx); + sc.configureBlocking(false); // set non blocking + Socket socket = sc.socket(); + socket.setTcpNoDelay(this.wss.isTcpNoDelay()); + socket.setKeepAlive(true); + r[selIdx].setRestart(true); // pause thread + selectors[selIdx].wakeup(); // causes a blocked selector operation to return immediately + WebSocketImpl w = this.wss.getWsf().createWebSocket(this.wss, this.wss.getDrafts()); + SelectionKey sk = sc.register(selectors[selIdx], SelectionKey.OP_READ, w); + w.setSelectionKey(sk); + try { + w.setChannel(this.wss.getWsf().wrapChannel(sc, w.getSelectionKey())); + allocateBuffers(w); + } catch (IOException ex) { + log.error(ex.getMessage(), ex); + if (w.getSelectionKey() != null) + w.getSelectionKey().cancel(); + handleIOException(w.getSelectionKey(), null, ex); + } catch (InterruptedException e) { + log.error(e.getMessage(), e); + } + selectors[selIdx].wakeup(); // causes a blocked selector operation to return immediately + r[selIdx].setRestart(false); // restart thread + if (++selIdx == selectors.length) + selIdx = 0; + } + } catch (IOException e) { + log.error(e.getMessage(), e); + } + } + + protected void allocateBuffers(WebSocket c) throws InterruptedException { + synchronized (this.wss) { + if (this.wss.getQueuesize().get() >= 2 * this.wss.getDecoders().size() + 1) { + return; + } + this.wss.getQueuesize().incrementAndGet(); + this.wss.getBuffers().put(createBuffer()); + } + } + + public ByteBuffer createBuffer() { + return ByteBuffer.allocate(WebSocketImpl.RCVBUF); + } + + private void handleIOException(SelectionKey key, WebSocket conn, IOException ex) { + // onWebsocketError( conn, ex );// conn may be null here + if (conn != null) { + conn.closeConnection(CloseFrame.ABNORMAL_CLOSE, ex.getMessage()); + } else if (key != null) { + SelectableChannel channel = key.channel(); + if (channel != null && channel.isOpen()) { // this could be the case if the IOException ex is a SSLException + try { + channel.close(); + } catch (IOException e) { + // there is nothing that must be done here + } + log.trace("Connection closed because of exception", ex); + } + } + } +} diff --git a/src/main/java/org/java_websocket/reactor/TCPReactor.java b/src/main/java/org/java_websocket/reactor/TCPReactor.java new file mode 100644 index 000000000..42c576f84 --- /dev/null +++ b/src/main/java/org/java_websocket/reactor/TCPReactor.java @@ -0,0 +1,79 @@ +package org.java_websocket.reactor; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.ClosedSelectorException; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.util.Iterator; +import java.util.Set; + +import org.java_websocket.server.WebSocketServer; +import org.java_websocket.server.WebSocketServer.WebSocketWorker; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * TCPReactor is the mainReactor, process new client connections + */ +public class TCPReactor implements Runnable { + private static final Logger log = LoggerFactory.getLogger(TCPReactor.class); + + private ServerSocketChannel ssc; + private Selector selector; + + public TCPReactor(InetSocketAddress address, ServerSocketChannel server, WebSocketServer wss, + Selector s) { + try { + this.ssc = server; + this.selector = s; + Acceptor acceptor = new Acceptor(server, wss); + SelectionKey sk = server.register(selector, SelectionKey.OP_ACCEPT); + sk.attach(acceptor); + wss.startConnectionLostTimer(); + for (WebSocketWorker ex : wss.getDecoders()) { + ex.start(); + } + wss.onStart(); + } catch (SocketException e) { + log.error(e.getMessage(), e); + } catch (ClosedChannelException e) { + log.error(e.getMessage(), e); + } catch (IOException e) { + log.error(e.getMessage(), e); + } + } + + @Override + public void run() { + while (!Thread.interrupted()) { + log.trace("mainReactor waiting for new event on port: " + ssc.socket().getLocalPort() + + "..."); + try { + if (selector.select() == 0) {// if no event is ready, do not proceed + continue; + } + Set selectedKeys = selector.selectedKeys(); + Iterator it = selectedKeys.iterator(); + while (it.hasNext()) { + dispatch((SelectionKey) (it.next())); + it.remove(); + } + } catch (IOException e) { + log.error(e.getMessage(), e); + } catch (ClosedSelectorException e) { + log.error(e.getMessage(), e); + } + } + } + + private void dispatch(SelectionKey key) { + Runnable r = (Runnable) (key.attachment()); + if (r != null) + r.run(); + } + +} diff --git a/src/main/java/org/java_websocket/reactor/TCPSubReactor.java b/src/main/java/org/java_websocket/reactor/TCPSubReactor.java new file mode 100644 index 000000000..cdf5e896b --- /dev/null +++ b/src/main/java/org/java_websocket/reactor/TCPSubReactor.java @@ -0,0 +1,63 @@ +package org.java_websocket.reactor; + +import java.io.IOException; +import java.nio.channels.ClosedSelectorException; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.util.Iterator; +import java.util.Set; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * TCPSubReactor is the subReactor, read/write network data and perform business processing, and throw it to the worker thread pool + */ +public class TCPSubReactor implements Runnable { + private static final Logger log = LoggerFactory.getLogger(TCPSubReactor.class); + private final Selector selector; + private volatile boolean restart = false; + int num; + + public TCPSubReactor(Selector selector, ServerSocketChannel ssc, int num) { + this.selector = selector; + this.num = num; + } + + @Override + public void run() { + while (!Thread.interrupted()) { // continue running until the thread is interrupted + log.trace("waiting for restart"); + while (!Thread.interrupted() && !restart) { // runs continuously until the thread is interrupted and designated to restart + try { + if (selector.select() == 0) {// ff no event is ready, do not proceed + continue; + } + Set selectedKeys = selector.selectedKeys(); // get the key collection of all ready events + Iterator it = selectedKeys.iterator(); + while (it.hasNext()) { + dispatch((SelectionKey) (it.next())); // schedule according to the key of the event + it.remove(); + } + } catch (IOException e) { + log.error(e.getMessage(), e); + } catch (ClosedSelectorException e) { + log.error(e.getMessage(), e); + } catch (Throwable t) { + log.error(t.getMessage(), t); + } + } + } + } + + private void dispatch(SelectionKey key) { + Runnable r = (Runnable) (key.attachment()); // open a new thread according to the object bound by the key of the event + if (r != null) + r.run(); + } + + public void setRestart(boolean restart) { + this.restart = restart; + } +} diff --git a/src/main/java/org/java_websocket/server/WebSocketServer.java b/src/main/java/org/java_websocket/server/WebSocketServer.java index bb8178c25..e65632eb2 100644 --- a/src/main/java/org/java_websocket/server/WebSocketServer.java +++ b/src/main/java/org/java_websocket/server/WebSocketServer.java @@ -32,7 +32,6 @@ import java.net.SocketAddress; import java.nio.ByteBuffer; import java.nio.channels.CancelledKeyException; -import java.nio.channels.ClosedByInterruptException; import java.nio.channels.SelectableChannel; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; @@ -47,12 +46,12 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; + import org.java_websocket.AbstractWebSocket; import org.java_websocket.SocketChannelIOHelper; import org.java_websocket.WebSocket; @@ -67,6 +66,7 @@ import org.java_websocket.framing.Framedata; import org.java_websocket.handshake.ClientHandshake; import org.java_websocket.handshake.Handshakedata; +import org.java_websocket.reactor.TCPReactor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -367,62 +367,6 @@ public void run() { if (!doSetupSelectorAndServerThread()) { return; } - try { - int shutdownCount = 5; - int selectTimeout = 0; - while (!selectorthread.isInterrupted() && shutdownCount != 0) { - SelectionKey key = null; - try { - if (isclosed.get()) { - selectTimeout = 5; - } - int keyCount = selector.select(selectTimeout); - if (keyCount == 0 && isclosed.get()) { - shutdownCount--; - } - Set keys = selector.selectedKeys(); - Iterator i = keys.iterator(); - - while (i.hasNext()) { - key = i.next(); - - if (!key.isValid()) { - continue; - } - - if (key.isAcceptable()) { - doAccept(key, i); - continue; - } - - if (key.isReadable() && !doRead(key, i)) { - continue; - } - - if (key.isWritable()) { - doWrite(key); - } - } - doAdditionalRead(); - } catch (CancelledKeyException e) { - // an other thread may cancel the key - } catch (ClosedByInterruptException e) { - return; // do the same stuff as when InterruptedException is thrown - } catch (WrappedIOException ex) { - handleIOException(key, ex.getConnection(), ex.getIOException()); - } catch (IOException ex) { - handleIOException(key, null, ex); - } catch (InterruptedException e) { - // FIXME controlled shutdown (e.g. take care of buffermanagement) - Thread.currentThread().interrupt(); - } - } - } catch (RuntimeException e) { - // should hopefully never occur - handleFatal(null, e); - } finally { - doServerShutdown(); - } } /** @@ -559,20 +503,17 @@ private void doWrite(SelectionKey key) throws WrappedIOException { private boolean doSetupSelectorAndServerThread() { selectorthread.setName("WebSocketSelector-" + selectorthread.getId()); try { - server = ServerSocketChannel.open(); - server.configureBlocking(false); - ServerSocket socket = server.socket(); - socket.setReceiveBufferSize(WebSocketImpl.RCVBUF); - socket.setReuseAddress(isReuseAddr()); - socket.bind(address, getMaxPendingConnections()); - selector = Selector.open(); - server.register(selector, server.validOps()); - startConnectionLostTimer(); - for (WebSocketWorker ex : decoders) { - ex.start(); - } - onStart(); - } catch (IOException ex) { + server = ServerSocketChannel.open(); + server.configureBlocking(false); + ServerSocket socket = server.socket(); + socket.setReceiveBufferSize(WebSocketImpl.RCVBUF); + socket.setReuseAddress(this.isReuseAddr()); + socket.bind(address, this.getMaxPendingConnections()); + selector = Selector.open(); + TCPReactor reactor = new TCPReactor(address, server, this, selector); + Thread thread = new Thread(reactor); + thread.start(); + } catch (Exception ex) { handleFatal(null, ex); return false; } @@ -642,7 +583,7 @@ public ByteBuffer createBuffer() { return ByteBuffer.allocate(WebSocketImpl.RCVBUF); } - protected void queue(WebSocketImpl ws) throws InterruptedException { + public void queue(WebSocketImpl ws) throws InterruptedException { if (ws.getWorkerThread() == null) { ws.setWorkerThread(decoders.get(queueinvokes % decoders.size())); queueinvokes++; @@ -650,11 +591,11 @@ protected void queue(WebSocketImpl ws) throws InterruptedException { ws.getWorkerThread().put(ws); } - private ByteBuffer takeBuffer() throws InterruptedException { + public ByteBuffer takeBuffer() throws InterruptedException { return buffers.take(); } - private void pushBuffer(ByteBuffer buf) throws InterruptedException { + public void pushBuffer(ByteBuffer buf) throws InterruptedException { if (buffers.size() > queuesize.intValue()) { return; } @@ -802,7 +743,7 @@ public final void onWriteDemand(WebSocket w) { // the thread which cancels key is responsible for possible cleanup conn.outQueue.clear(); } - selector.wakeup(); + conn.getSelectionKey().selector().wakeup(); } @Override @@ -1119,4 +1060,28 @@ private void doDecode(WebSocketImpl ws, ByteBuffer buf) throws InterruptedExcept } } } + + public WebSocketServerFactory getWsf() { + return wsf; + } + + public List getIqueue() { + return iqueue; + } + + public List getDecoders() { + return decoders; + } + + public List getDrafts() { + return drafts; + } + + public BlockingQueue getBuffers() { + return buffers; + } + + public AtomicInteger getQueuesize() { + return queuesize; + } }