diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransport.java new file mode 100644 index 00000000..b6d8e991 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransport.java @@ -0,0 +1,11 @@ +package io.modelcontextprotocol.client.transport; + +import java.net.UnixDomainSocketAddress; + +import io.modelcontextprotocol.spec.McpClientTransport; + +public interface UdsMcpClientTransport extends McpClientTransport { + + UnixDomainSocketAddress getUdsAddress(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransportImpl.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransportImpl.java new file mode 100644 index 00000000..6599d69a --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UdsMcpClientTransportImpl.java @@ -0,0 +1,248 @@ +package io.modelcontextprotocol.client.transport; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.channels.SelectionKey; +import java.time.Duration; +import java.util.concurrent.Executors; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.UDSClientSocketChannel; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class UdsMcpClientTransportImpl implements UdsMcpClientTransport { + + private static final Logger logger = LoggerFactory.getLogger(UdsMcpClientTransportImpl.class); + + private final Sinks.Many inboundSink; + + private final Sinks.Many outboundSink; + + private ObjectMapper objectMapper; + + /** Scheduler for handling outbound messages to the server process */ + private Scheduler outboundScheduler; + + private final Sinks.Many errorSink; + + private UDSClientSocketChannel clientChannel; + + private UnixDomainSocketAddress targetAddress; + + private volatile boolean isClosing = false; + + // visible for tests + private Consumer stdErrorHandler = error -> logger.info("STDERR Message received: {}", error); + + public UnixDomainSocketAddress getUdsAddress() { + return this.targetAddress; + } + + public UdsMcpClientTransportImpl(UnixDomainSocketAddress targetAddress) { + this(new ObjectMapper(), targetAddress); + } + + public UdsMcpClientTransportImpl(ObjectMapper objectMapper, UnixDomainSocketAddress targetAddress) { + Assert.notNull(objectMapper, "objectMapper cannot be null"); + this.objectMapper = objectMapper; + Assert.notNull(objectMapper, "targetAddress cannot be null"); + this.targetAddress = targetAddress; + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.errorSink = Sinks.many().unicast().onBackpressureBuffer(); + try { + this.clientChannel = new UDSClientSocketChannel() { + @Override + protected void handleException(SelectionKey key, Throwable e) { + isClosing = true; + super.handleException(key, e); + } + }; + } + catch (IOException e) { + throw new RuntimeException(e); + } + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound"); + } + + /** + * Starts the server process and initializes the message processing streams. This + * method sets up the process with the configured command, arguments, and environment, + * then starts the inbound, outbound, and error processing threads. + * @throws RuntimeException if the process fails to start or if the process streams + * are null + */ + @Override + public Mono connect(Function, Mono> handler) { + return Mono.fromRunnable(() -> { + handleIncomingMessages(handler); + handleIncomingErrors(); + + // Connect client channel + try { + this.clientChannel.connect(targetAddress, (client) -> { + if (logger.isInfoEnabled()) { + logger.info("UdsMcpClientTransportImpl CONNECTED to targetAddress=" + targetAddress); + } + }, (message) -> { + if (logger.isDebugEnabled()) { + logger.debug("received message=" + message); + } + // Incoming messages processed right here + McpSchema.JSONRPCMessage jsonMessage = McpSchema.deserializeJsonRpcMessage(objectMapper, message); + if (!this.inboundSink.tryEmitNext(jsonMessage).isSuccess()) { + if (!isClosing) { + if (logger.isDebugEnabled()) { + logger.error("Failed to enqueue inbound json message: {}", jsonMessage); + } + } + } + }); + } + catch (IOException e) { + this.clientChannel.close(); + throw new RuntimeException( + "Connect to address=" + targetAddress + " failed message: " + e.getMessage()); + } + + startOutboundProcessing(); + + }).subscribeOn(Schedulers.boundedElastic()); + } + + /** + * Sets the handler for processing transport-level errors. + * + *

+ * The provided handler will be called when errors occur during transport operations, + * such as connection failures or protocol violations. + *

+ * @param errorHandler a consumer that processes error messages + */ + public void setStdErrorHandler(Consumer errorHandler) { + this.stdErrorHandler = errorHandler; + } + + private void handleIncomingMessages(Function, Mono> inboundMessageHandler) { + this.inboundSink.asFlux() + .flatMap(message -> Mono.just(message) + .transform(inboundMessageHandler) + .contextWrite(ctx -> ctx.put("observation", "myObservation"))) + .subscribe(); + } + + private void handleIncomingErrors() { + this.errorSink.asFlux().subscribe(e -> { + this.stdErrorHandler.accept(e); + }); + } + + @Override + public Mono sendMessage(JSONRPCMessage message) { + outboundSink.emitNext(message, (signalType, emitResult) -> { + // Allow retry + return true; + }); + return Mono.empty(); + } + + /** + * Starts the outbound processing thread that writes JSON-RPC messages to the + * process's output stream. Messages are serialized to JSON and written with a newline + * delimiter. + */ + private void startOutboundProcessing() { + this.handleOutbound(messages -> messages + // this bit is important since writes come from user threads, and we + // want to ensure that the actual writing happens on a dedicated thread + .publishOn(outboundScheduler) + .handle((message, sink) -> { + if (message != null && !isClosing) { + try { + clientChannel.writeMessage(objectMapper.writeValueAsString(message)); + sink.next(message); + } + catch (IOException e) { + if (!isClosing) { + logger.error("Error writing message", e); + sink.error(new RuntimeException(e)); + } + else { + logger.debug("Stream closed during shutdown", e); + } + } + } + })); + } + + protected void handleOutbound(Function, Flux> outboundConsumer) { + outboundConsumer.apply(outboundSink.asFlux()).doOnComplete(() -> { + isClosing = true; + outboundSink.tryEmitComplete(); + }).doOnError(e -> { + if (!isClosing) { + logger.error("Error in outbound processing", e); + isClosing = true; + outboundSink.tryEmitComplete(); + } + }).subscribe(); + } + + /** + * Gracefully closes the transport by destroying the process and disposing of the + * schedulers. This method sends a TERM signal to the process and waits for it to exit + * before cleaning up resources. + * @return A Mono that completes when the transport is closed + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing = true; + logger.debug("Initiating graceful shutdown"); + }).then(Mono.defer(() -> { + // First complete all sinks to stop accepting new messages + inboundSink.tryEmitComplete(); + outboundSink.tryEmitComplete(); + errorSink.tryEmitComplete(); + + // Give a short time for any pending messages to be processed + return Mono.delay(Duration.ofMillis(100)).then(); + })).then(Mono.fromRunnable(() -> { + try { + outboundScheduler.dispose(); + if (this.clientChannel != null) { + this.clientChannel.close(); + this.clientChannel = null; + } + logger.debug("Graceful shutdown completed"); + } + catch (Exception e) { + logger.error("Error during graceful shutdown", e); + } + })).then().subscribeOn(Schedulers.boundedElastic()); + } + + public Sinks.Many getErrorSink() { + return this.errorSink; + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return this.objectMapper.convertValue(data, typeRef); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProvider.java new file mode 100644 index 00000000..84d69cd0 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProvider.java @@ -0,0 +1,11 @@ +package io.modelcontextprotocol.server.transport; + +import java.net.UnixDomainSocketAddress; + +import io.modelcontextprotocol.spec.McpServerTransportProvider; + +public interface UdsMcpServerTransportProvider extends McpServerTransportProvider { + + UnixDomainSocketAddress getUdsAddress(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java new file mode 100644 index 00000000..7554bd91 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UdsMcpServerTransportProviderImpl.java @@ -0,0 +1,260 @@ +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.channels.SelectionKey; +import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.UDSServerSocketChannel; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class UdsMcpServerTransportProviderImpl implements UdsMcpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(UdsMcpServerTransportProviderImpl.class); + + private final ObjectMapper objectMapper; + + private UDSMcpSessionTransport transport; + + private McpServerSession session; + + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + private final Sinks.One inboundReady = Sinks.one(); + + private final Sinks.One outboundReady = Sinks.one(); + + private UnixDomainSocketAddress targetAddress; + + public UnixDomainSocketAddress getUdsAddress() { + return targetAddress; + } + + /** + * Creates a new UdsMcpServerTransportProviderImpl with a default ObjectMapper + * @param unixSocketAddress the UDS socket address to bind to. Must not be null. + */ + public UdsMcpServerTransportProviderImpl(UnixDomainSocketAddress unixSocketAddress) { + this(new ObjectMapper(), unixSocketAddress); + } + + /** + * Creates a new UdsMcpServerTransportProviderImpl with the specified ObjectMapper + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + */ + public UdsMcpServerTransportProviderImpl(ObjectMapper objectMapper, UnixDomainSocketAddress unixSocketAddress) { + Assert.notNull(objectMapper, "objectMapper cannot be null"); + this.objectMapper = objectMapper; + Assert.notNull(unixSocketAddress, "unixSocketAddress cannot be null"); + this.targetAddress = unixSocketAddress; + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.transport = new UDSMcpSessionTransport(); + this.session = sessionFactory.create(transport); + this.transport.initProcessing(); + } + + @Override + public Mono notifyClients(String method, Object params) { + if (this.session == null) { + return Mono.error(new Exception("No session to close")); + } + return this.session.sendNotification(method, params) + .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage())); + } + + @Override + public Mono closeGracefully() { + if (this.session == null) { + return Mono.empty(); + } + return this.session.closeGracefully(); + } + + /** + * Implementation of McpServerTransport for the uds session. + */ + private class UDSMcpSessionTransport implements McpServerTransport { + + private final Sinks.Many inboundSink; + + private final Sinks.Many outboundSink; + + /** Scheduler for handling outbound messages */ + private Scheduler outboundScheduler; + + private final AtomicBoolean isStarted = new AtomicBoolean(false); + + private final UDSServerSocketChannel serverSocketChannel; + + public UDSMcpSessionTransport() { + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), + "uds-outbound"); + try { + this.serverSocketChannel = new UDSServerSocketChannel() { + @Override + protected void handleException(SelectionKey key, Throwable e) { + isClosing.set(true); + if (session != null) { + session.close(); + session = null; + } + inboundSink.tryEmitComplete(); + } + }; + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.zip(inboundReady.asMono(), outboundReady.asMono()).then(Mono.defer(() -> { + outboundSink.emitNext(message, (signalType, emitResult) -> { + // Allow retry + return true; + }); + return Mono.empty(); + })); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing.set(true); + logger.debug("Session transport closing gracefully"); + inboundSink.tryEmitComplete(); + }); + } + + @Override + public void close() { + isClosing.set(true); + logger.debug("Session transport closed"); + } + + private void initProcessing() { + handleIncomingMessages(); + startInboundProcessing(); + startOutboundProcessing(); + + inboundReady.tryEmitValue(null); + outboundReady.tryEmitValue(null); + } + + private void handleIncomingMessages() { + this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> { + this.outboundSink.tryEmitComplete(); + }).subscribe(); + } + + /** + * Starts the inbound processing thread that reads JSON-RPC messages from stdin. + * Messages are deserialized and passed to the session for handling. + */ + private void startInboundProcessing() { + if (isStarted.compareAndSet(false, true)) { + try { + this.serverSocketChannel.start(targetAddress, (clientChannel) -> { + if (logger.isDebugEnabled()) { + logger.debug("Accepted connect from clientChannel=" + clientChannel); + } + }, (message) -> { + if (logger.isDebugEnabled()) { + logger.debug("Received message=" + message); + } + // Incoming messages processed right here + McpSchema.JSONRPCMessage jsonMessage = McpSchema.deserializeJsonRpcMessage(objectMapper, + message); + if (!this.inboundSink.tryEmitNext(jsonMessage).isSuccess()) { + throw new IOException("Error adding jsonMessge to inboundSink"); + } + }); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + } + + /** + * Starts the outbound processing thread that writes JSON-RPC messages to stdout. + * Messages are serialized to JSON and written with a newline delimiter. + */ + private void startOutboundProcessing() { + Function, Flux> outboundConsumer = messages -> messages // @formatter:off + .doOnSubscribe(subscription -> outboundReady.tryEmitValue(null)) + .publishOn(outboundScheduler) + .handle((message, sink) -> { + if (message != null && !isClosing.get()) { + try { + serverSocketChannel.writeMessage(objectMapper.writeValueAsString(message)); + sink.next(message); + } + catch (IOException e) { + if (!isClosing.get()) { + logger.error("Error writing message", e); + sink.error(new RuntimeException(e)); + } + else { + logger.debug("Stream closed during shutdown", e); + } + } + } + else if (isClosing.get()) { + sink.complete(); + } + }) + .doOnComplete(() -> { + isClosing.set(true); + outboundScheduler.dispose(); + }) + .doOnError(e -> { + if (!isClosing.get()) { + logger.error("Error in outbound processing", e); + isClosing.set(true); + outboundScheduler.dispose(); + } + }) + .map(msg -> (JSONRPCMessage) msg); + + outboundConsumer.apply(outboundSink.asFlux()).subscribe(); + } // @formatter:on + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java new file mode 100644 index 00000000..36501502 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/AbstractSocketChannel.java @@ -0,0 +1,355 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.io.InterruptedIOException; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.charset.StandardCharsets; +import java.util.Iterator; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public abstract class AbstractSocketChannel { + + private static final Logger logger = LoggerFactory.getLogger(AbstractSocketChannel.class); + + public static final int DEFAULT_INBUFFER_SIZE = 1024; + + public static String DEFAULT_MESSAGE_DELIMITER = "\n"; + + protected String messageDelimiter = DEFAULT_MESSAGE_DELIMITER; + + protected void setMessageDelimiter(String delim) { + this.messageDelimiter = delim; + } + + public static int DEFAULT_WRITE_TIMEOUT = 5000; // ms + + protected int writeTimeout = DEFAULT_WRITE_TIMEOUT; + + protected void setWriteTimeout(int timeout) { + this.writeTimeout = timeout; + } + + public static int DEFAULT_CONNECT_TIMEOUT = 10000; // ms + + protected int connectTimeout = DEFAULT_CONNECT_TIMEOUT; + + protected void setConnectTimeout(int timeout) { + this.connectTimeout = timeout; + } + + public static int DEFAULT_TERMINATION_TIMEOUT = 2000; // ms + + protected int terminationTimeout = DEFAULT_TERMINATION_TIMEOUT; + + protected void setTerminationTimeout(int timeout) { + this.terminationTimeout = timeout; + } + + protected final Selector selector; + + protected final ByteBuffer inBuffer; + + protected final ExecutorService executor; + + private final Object writeLock = new Object(); + + @FunctionalInterface + public interface IOConsumer { + + void apply(T t) throws IOException; + + } + + protected class AttachedIO { + + public ByteBuffer writing; + + public StringBuffer reading; + + } + + public AbstractSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + Assert.notNull(selector, "Selector must not be null"); + this.selector = selector; + this.inBuffer = ByteBuffer.allocate(incomingBufferSize); + this.executor = (executor == null) ? Executors.newSingleThreadExecutor() : executor; + } + + public AbstractSocketChannel(Selector selector, int incomingBufferSize) { + this(selector, incomingBufferSize, null); + } + + public AbstractSocketChannel(Selector selector) { + this(selector, DEFAULT_INBUFFER_SIZE); + } + + public AbstractSocketChannel() throws IOException { + this(Selector.open()); + } + + protected Runnable getRunnableForProcessing(IOConsumer acceptHandler, + IOConsumer connectHandler, IOConsumer readHandler) { + return () -> { + SelectionKey key = null; + try { + while (true) { + int count = this.selector.select(); + debug("select returned count={}", count); + Set selectedKeys = selector.selectedKeys(); + Iterator iter = selectedKeys.iterator(); + while (iter.hasNext()) { + key = iter.next(); + if (key.isConnectable()) { + handleConnectable(key, connectHandler); + } + else if (key.isAcceptable()) { + handleAcceptable(key, acceptHandler); + } + else if (key.isReadable()) { + handleReadable(key, readHandler); + } + else if (key.isWritable()) { + handleWritable(key); + } + iter.remove(); + } + } + } + catch (Throwable e) { + handleException(key, e); + } + }; + } + + public abstract void close(); + + protected abstract void handleException(SelectionKey key, Throwable e); + + protected void start(IOConsumer acceptHandler, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + this.executor.execute(getRunnableForProcessing(acceptHandler, connectHandler, readHandler)); + } + + protected void debug(String format, Object... o) { + if (logger.isDebugEnabled()) { + logger.debug(format, o); + } + } + + // For client subclasses + protected void handleConnectable(SelectionKey key, IOConsumer connectHandler) throws IOException { + SocketChannel client = (SocketChannel) key.channel(); + debug("client={}", client); + client.configureBlocking(false); + client.register(this.selector, SelectionKey.OP_READ, new AttachedIO()); + if (client.isConnectionPending()) { + client.finishConnect(); + debug("connected client={}", client); + } + if (connectHandler != null) { + connectHandler.apply(client); + } + } + + protected void handleAcceptable(SelectionKey key, IOConsumer acceptHandler) throws IOException { + ServerSocketChannel serverSocket = (ServerSocketChannel) key.channel(); + SocketChannel client = serverSocket.accept(); + debug("client={}", client); + client.configureBlocking(false); + client.register(this.selector, SelectionKey.OP_READ, new AttachedIO()); + configureAcceptSocketChannel(client); + if (client.isConnectionPending()) { + client.finishConnect(); + debug("accepted client={}", client); + } + if (acceptHandler != null) { + acceptHandler.apply(client); + } + } + + protected void configureAcceptSocketChannel(SocketChannel client) throws IOException { + // Subclasses may override + } + + protected AttachedIO getAttachedIO(SelectionKey key) throws IOException { + AttachedIO io = (AttachedIO) key.attachment(); + if (io == null) { + throw new IOException("No AttachedIO object found on key"); + } + return io; + } + + protected void handleReadable(SelectionKey key, IOConsumer readHandler) throws IOException { + SocketChannel client = (SocketChannel) key.channel(); + AttachedIO io = getAttachedIO(key); + debug("read client={}", client); + // read + int r = client.read(this.inBuffer); + // Check if we should expect any more reads + if (r == -1) { + throw new IOException("Channel read reached end of stream"); + } + this.inBuffer.flip(); + String partial = new String(this.inBuffer.array(), 0, r, StandardCharsets.UTF_8); + // If there is previous partial, get the io.reading string Buffer + StringBuffer sb = (io.reading != null) ? (StringBuffer) io.reading : new StringBuffer(); + // append the just read partial to the existing or new string buffer + sb.append(partial); + if (partial.endsWith(messageDelimiter)) { + // Get the entire message from the string buffer + String message = sb.toString(); + // Set the io.reading value to null as we are done with this message + io.reading = null; + debug("read client={} msg=", client, message); + if (readHandler != null) { + String[] messages = splitMessage(message); + for (int i = 0; i < messages.length; i++) { + readHandler.apply(messages[i]); + } + } + } + else { + io.reading = sb; + debug("read partial={}", partial); + } + // Clear inbuffer for next read + this.inBuffer.clear(); + } + + protected void handleWritable(SelectionKey key) throws IOException { + ByteBuffer buf = getAttachedIO(key).writing; + SocketChannel client = (SocketChannel) key.channel(); + if (buf != null) { + doWrite(key, client, buf, (o) -> { + synchronized (writeLock) { + writeLock.notifyAll(); + } + }); + } + } + + protected void doWrite(SocketChannel client, String message, IOConsumer writeHandler) throws IOException { + Assert.notNull(client, "Client must not be null"); + Assert.notNull(message, "Message must not be null"); + doWrite(client.keyFor(this.selector), client, ByteBuffer.wrap(message.getBytes(StandardCharsets.UTF_8)), + writeHandler); + } + + protected void doWrite(SelectionKey key, SocketChannel client, ByteBuffer buf, IOConsumer writeHandler) + throws IOException { + AttachedIO io = (AttachedIO) key.attachment(); + synchronized (writeLock) { + int written = client.write(buf); + if (buf.hasRemaining()) { + debug("doWrite written={}, remaining={}", written, buf.remaining()); + io.writing = buf.slice(); + key.interestOpsOr(SelectionKey.OP_WRITE); + } + else { + if (logger.isDebugEnabled()) { + logger.debug("doWrite message={}", new String(buf.array(), 0, written)); + } + io.writing = null; + key.interestOps(SelectionKey.OP_READ); + if (writeHandler != null) { + writeHandler.apply(null); + } + } + } + } + + protected void executorShutdown() { + if (!this.executor.isShutdown()) { + debug("shutdown"); + try { + this.executor.awaitTermination(this.terminationTimeout, TimeUnit.MILLISECONDS); + this.executor.shutdown(); + } + catch (InterruptedException e) { + if (logger.isDebugEnabled()) { + logger.debug("Exception in executor awaitTermination", e); + } + } + } + } + + protected void hardCloseClient(SocketChannel client, IOConsumer closeHandler) { + if (client != null) { + debug("hardClose client={}", client); + synchronized (writeLock) { + try { + if (closeHandler != null) { + closeHandler.apply(client); + } + client.close(); + } + catch (IOException e) { + if (logger.isDebugEnabled()) { + logger.debug("hardClose client socketchannel.close exception", e); + } + } + } + executorShutdown(); + } + } + + protected void writeMessageToChannel(SocketChannel client, String message) throws IOException { + Objects.requireNonNull(client, "Client must not be null"); + Objects.requireNonNull(message, "Message must not be null"); + // Escape any embedded newlines in the JSON message + String outputMessage = message.replace("\r\n", "\\n") + .replace("\n", "\\n") + .replace("\r", "\\n") + // add message delimiter + .concat(DEFAULT_MESSAGE_DELIMITER); + debug("writing msg={}", outputMessage); + synchronized (writeLock) { + // do the non blocking write in thread while holding lock. + doWrite(client, outputMessage, null); + ByteBuffer bufRemaining = null; + long waitTime = System.currentTimeMillis() + this.writeTimeout; + while (waitTime - System.currentTimeMillis() > 0) { + // Before releasing lock, check for writing buffer remaining + bufRemaining = getAttachedIO(client.keyFor(this.selector)).writing; + if (bufRemaining == null || bufRemaining.remaining() == 0) { + // It's done + break; + } + // If write is *not* completed, then wait timeout /10 + try { + debug("writeBlocking WAITING(ms)={} msg={}", String.valueOf(waitTime / 10), outputMessage); + writeLock.wait(waitTime / 10); + } + catch (InterruptedException e) { + throw new InterruptedIOException("write message wait interrupted"); + } + } + if (bufRemaining != null && bufRemaining.remaining() > 0) { + throw new IOException("Write not completed. Non empty buffer remaining after timeout"); + } + } + debug("writing done msg={}", outputMessage); + } + + protected void configureConnectSocketChannel(SocketChannel client, SocketAddress connectAddress) + throws IOException { + // Subclasses may override + } + + protected String[] splitMessage(String message) { + return (message == null) ? new String[0] : message.split(messageDelimiter); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java new file mode 100644 index 00000000..edda1671 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ClientSocketChannel.java @@ -0,0 +1,100 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.SocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ClientSocketChannel extends AbstractSocketChannel { + + private static final Logger logger = LoggerFactory.getLogger(ClientSocketChannel.class); + + protected SocketChannel client; + + protected final Object connectLock = new Object(); + + public ClientSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public ClientSocketChannel() throws IOException { + super(); + } + + public ClientSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public ClientSocketChannel(Selector selector) { + super(selector); + } + + protected SocketChannel doConnect(SocketChannel client, SocketAddress address, + IOConsumer connectHandler, IOConsumer readHandler) throws IOException { + debug("connect targetAddress={}", address); + client.configureBlocking(false); + client.register(selector, SelectionKey.OP_CONNECT); + configureConnectSocketChannel(client, address); + // Start the read thread before connect + // No/null accept handler for clients + start(null, (c) -> { + synchronized (connectLock) { + if (connectHandler != null) { + connectHandler.apply(c); + } + connectLock.notifyAll(); + } + }, readHandler); + + client.connect(address); + try { + debug("connect targetAddress={}", address); + synchronized (connectLock) { + connectLock.wait(this.connectTimeout); + } + } + catch (InterruptedException e) { + throw new IOException( + "Connect to address=" + address + " timed out after " + String.valueOf(this.connectTimeout) + "ms"); + } + debug("connected client={}", client); + return client; + } + + public void connect(StandardProtocolFamily protocol, SocketAddress address, + IOConsumer connectHandler, IOConsumer readHandler) throws IOException { + if (this.client != null) { + throw new IOException("Already connected"); + } + this.client = doConnect(SocketChannel.open(protocol), address, connectHandler, readHandler); + } + + @Override + protected void handleException(SelectionKey key, Throwable e) { + if (logger.isDebugEnabled()) { + logger.debug("handleException", e); + } + close(); + } + + @Override + public void close() { + hardCloseClient(this.client, (client) -> { + this.client = null; + }); + } + + public void writeMessage(String message) throws IOException { + if (this.client == null) { + throw new IOException("Cannot write until client connected"); + } + writeMessageToChannel(client, message); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java new file mode 100644 index 00000000..0cdfd4ee --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ServSocketChannel.java @@ -0,0 +1,92 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.SocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ServSocketChannel extends AbstractSocketChannel { + + private static final Logger logger = LoggerFactory.getLogger(ServSocketChannel.class); + + protected SocketChannel acceptedClient; + + public ServSocketChannel() throws IOException { + super(); + } + + public ServSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public ServSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public ServSocketChannel(Selector selector) { + super(selector); + } + + protected void configureServerSocketChannel(ServerSocketChannel serverSocketChannel, SocketAddress acceptAddress) { + // Subclasses may override + } + + public void start(StandardProtocolFamily protocol, SocketAddress address, IOConsumer acceptHandler, + IOConsumer readHandler) throws IOException { + ServerSocketChannel serverChannel = ServerSocketChannel.open(protocol); + serverChannel.configureBlocking(false); + serverChannel.register(this.selector, SelectionKey.OP_ACCEPT); + configureServerSocketChannel(serverChannel, address); + serverChannel.bind(address); + // Start thread/processing of incoming accept, read + super.start((client) -> { + if (logger.isDebugEnabled()) { + logger.debug("Setting client=" + client); + } + this.acceptedClient = client; + if (acceptHandler != null) { + acceptHandler.apply(this.acceptedClient); + } + // No/null connect handler for Acceptors...only accepthandler + }, null, readHandler); + } + + @Override + protected void handleException(SelectionKey key, Throwable e) { + if (logger.isDebugEnabled()) { + logger.debug("handleException", e); + } + close(); + } + + public void writeMessage(String message) throws IOException { + SocketChannel c = this.acceptedClient; + if (c != null) { + writeMessageToChannel(c, message); + } + else { + throw new IOException("not connected"); + } + } + + @Override + public void close() { + SocketChannel client = this.acceptedClient; + if (client != null) { + hardCloseClient(client, (c) -> { + if (logger.isDebugEnabled()) { + logger.debug("Unsetting client=" + c); + } + this.acceptedClient = null; + }); + } + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientSocketChannel.java new file mode 100644 index 00000000..93539c85 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientSocketChannel.java @@ -0,0 +1,33 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.StandardProtocolFamily; +import java.net.UnixDomainSocketAddress; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class UDSClientSocketChannel extends ClientSocketChannel { + + public UDSClientSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public UDSClientSocketChannel() throws IOException { + super(); + } + + public UDSClientSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public UDSClientSocketChannel(Selector selector) { + super(selector); + } + + public void connect(UnixDomainSocketAddress address, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + super.connect(StandardProtocolFamily.UNIX, address, connectHandler, readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerSocketChannel.java new file mode 100644 index 00000000..a6607cf1 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerSocketChannel.java @@ -0,0 +1,33 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.StandardProtocolFamily; +import java.net.UnixDomainSocketAddress; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class UDSServerSocketChannel extends ServSocketChannel { + + public UDSServerSocketChannel() throws IOException { + super(); + } + + public UDSServerSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public UDSServerSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public UDSServerSocketChannel(Selector selector) { + super(selector); + } + + public void start(UnixDomainSocketAddress address, IOConsumer acceptHandler, + IOConsumer readHandler) throws IOException { + super.start(StandardProtocolFamily.UNIX, address, acceptHandler, readHandler); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java new file mode 100644 index 00000000..701ad3f0 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java @@ -0,0 +1,65 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.client.transport.UdsMcpClientTransportImpl; +import io.modelcontextprotocol.server.TestEverythingServer; +import io.modelcontextprotocol.server.transport.UdsMcpServerTransportProviderImpl; +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Tests for the {@link McpAyncClient} with {@link UDSClientTransport}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Scott Lewis + */ +@Timeout(150) // Giving extra time beyond the client timeout +class UDSMcpAsyncClientTests extends AbstractMcpAsyncClientTests { + + private Path socketPath = Paths.get(getClass().getName() + ".unix.socket"); + + private void deleteSocketPath() { + try { + Files.deleteIfExists(socketPath); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected void onStart() { + super.onStart(); + deleteSocketPath(); + this.server = new TestEverythingServer( + new UdsMcpServerTransportProviderImpl(UnixDomainSocketAddress.of(socketPath))); + } + + @Override + protected void onClose() { + super.onClose(); + if (server != null) { + server.closeGracefully(); + server = null; + } + deleteSocketPath(); + } + + private TestEverythingServer server; + + @Override + protected McpClientTransport createMcpTransport() { + return new UdsMcpClientTransportImpl(UnixDomainSocketAddress.of(socketPath)); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java new file mode 100644 index 00000000..93e19f90 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java @@ -0,0 +1,70 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.Duration; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.client.transport.UdsMcpClientTransportImpl; +import io.modelcontextprotocol.server.TestEverythingServer; +import io.modelcontextprotocol.server.transport.UdsMcpServerTransportProviderImpl; +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Tests for the {@link McpSyncClient} with {@link UDSClientTransport}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Scott Lewis + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpSyncClientTests extends AbstractMcpSyncClientTests { + + private Path socketPath = Paths.get(getClass().getName() + ".unix.socket"); + + private void deleteSocketPath() { + try { + Files.deleteIfExists(socketPath); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected void onStart() { + super.onStart(); + deleteSocketPath(); + this.server = new TestEverythingServer( + new UdsMcpServerTransportProviderImpl(UnixDomainSocketAddress.of(socketPath))); + } + + @Override + protected void onClose() { + super.onClose(); + if (server != null) { + server.closeGracefully(); + server = null; + } + deleteSocketPath(); + } + + private TestEverythingServer server; + + @Override + protected McpClientTransport createMcpTransport() { + return new UdsMcpClientTransportImpl(UnixDomainSocketAddress.of(socketPath)); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/TestEverythingServer.java b/mcp/src/test/java/io/modelcontextprotocol/server/TestEverythingServer.java new file mode 100644 index 00000000..6d5fcf9b --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/TestEverythingServer.java @@ -0,0 +1,148 @@ +package io.modelcontextprotocol.server; + +import java.util.List; + +import io.modelcontextprotocol.server.McpServerFeatures.SyncPromptSpecification; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema.Annotations; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest.ContextInclusionStrategy; + +public class TestEverythingServer { + + private static final String TEST_RESOURCE_URI = "test://resources/"; + + private static final String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + private McpSyncServer server; + + public TestEverythingServer(McpServerTransportProvider transport) { + McpServerFeatures.SyncResourceSpecification[] specs = new McpServerFeatures.SyncResourceSpecification[10]; + for (int i = 0; i < 10; i++) { + String istr = String.valueOf(i); + String uri = TEST_RESOURCE_URI + istr; + specs[i] = new McpServerFeatures.SyncResourceSpecification( + Resource.builder() + .uri(uri) + .name("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(), + (exchange, + req) -> new ReadResourceResult(List.of(new TextResourceContents(uri, "text/plain", istr)))); + } + + this.server = McpServer.sync(transport) + .serverInfo(getClass().getName() + "-server", "1.0.0") + .capabilities( + ServerCapabilities.builder().logging().tools(true).prompts(true).resources(true, true).build()) + .toolCall(Tool.builder() + .name("echo") + .description("echo tool description") + .inputSchema(emptyJsonSchema) + .build(), (exchange, request) -> { + return CallToolResult.builder().addTextContent((String) request.arguments().get("message")).build(); + }) + .toolCall(Tool.builder().name("add").description("add two integers").inputSchema(emptyJsonSchema).build(), + (exchange, request) -> { + Integer a = (Integer) request.arguments().get("a"); + Integer b = (Integer) request.arguments().get("b"); + + return CallToolResult.builder().addTextContent(String.valueOf(a + b)).build(); + }) + .toolCall( + Tool.builder().name("sampleLLM").description("sampleLLM tool").inputSchema(emptyJsonSchema).build(), + (exchange, request) -> { + String prompt = (String) request.arguments().get("prompt"); + Integer maxTokens = (Integer) request.arguments().get("maxTokens"); + SamplingMessage sm = new SamplingMessage(McpSchema.Role.USER, + new TextContent("Resource sampleLLM context: " + prompt)); + CreateMessageRequest cmRequest = CreateMessageRequest.builder() + .messages(List.of(sm)) + .systemPrompt("You are a helpful test server.") + .maxTokens(maxTokens) + .temperature(0.7) + .includeContext(ContextInclusionStrategy.THIS_SERVER) + .build(); + CreateMessageResult result = exchange.createMessage(cmRequest); + + return CallToolResult.builder() + .addTextContent("LLM sampling result: " + ((TextContent) result.content()).text()) + .build(); + }) + .toolCall(Tool.builder() + .name("longRunningOperation") + .description("Demonstrates a long running operation with progress updates") + .inputSchema(emptyJsonSchema) + .build(), (exchange, request) -> { + String progressToken = (String) request.progressToken(); + int steps = (Integer) request.arguments().get("steps"); + for (int i = 0; i < steps; i++) { + try { + Thread.sleep(1000); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + if (progressToken != null) { + exchange.progressNotification(new ProgressNotification(progressToken, (double) i + 1, + (double) steps, "progress message " + String.valueOf(i + 1))); + } + } + return CallToolResult.builder().content(List.of(new TextContent("done"))).build(); + }) + .toolCall(Tool.builder().name("annotatedMessage").description("annotated message").build(), + (exchange, request) -> { + String messageType = (String) request.arguments().get("messageType"); + Annotations annotations = null; + if (messageType.equals("success")) { + annotations = new Annotations(List.of(McpSchema.Role.USER), 0.7); + } + else if (messageType.equals("error")) { + annotations = new Annotations(List.of(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 1.0); + } + else if (messageType.equals("debug")) { + annotations = new Annotations(List.of(McpSchema.Role.ASSISTANT), 0.3); + } + return CallToolResult.builder() + .addContent(new TextContent(annotations, "some response")) + .build(); + }) + .prompts(List.of(new SyncPromptSpecification(new Prompt("simple_prompt", "Simple prompt description", null), + (exchange, request) -> { + return new GetPromptResult("description", + List.of(new PromptMessage(Role.USER, new TextContent("hello")))); + }))) + .resources(specs) + .build(); + } + + public void closeGracefully() { + if (this.server != null) { + this.server.closeGracefully(); + this.server = null; + } + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java new file mode 100644 index 00000000..74b088a8 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java @@ -0,0 +1,58 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.server.transport.UdsMcpServerTransportProviderImpl; +import io.modelcontextprotocol.spec.McpServerTransportProvider; + +/** + * Tests for {@link McpAsyncServer} using {@link UDSServerTransport}. + * + * @author Christian Tzolov + * @author Scott Lewis + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + private Path socketPath = Paths.get(getClass().getName() + ".unix.socket"); + + private void deleteSocketPath() { + try { + Files.deleteIfExists(socketPath); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected void onStart() { + super.onStart(); + deleteSocketPath(); + } + + @Override + protected void onClose() { + super.onClose(); + deleteSocketPath(); + } + + protected McpServerTransportProvider createMcpTransportProvider() { + return new UdsMcpServerTransportProviderImpl(UnixDomainSocketAddress.of(socketPath)); + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java new file mode 100644 index 00000000..e36febaf --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java @@ -0,0 +1,58 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.server.transport.UdsMcpServerTransportProviderImpl; +import io.modelcontextprotocol.spec.McpServerTransportProvider; + +/** + * Tests for {@link McpSyncServer} using {@link UdsMcpServerTransportProviderImpl}. + * + * @author Christian Tzolov + * @author Scott Lewis + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpSyncServerTests extends AbstractMcpSyncServerTests { + + private Path socketPath = Paths.get(getClass().getName() + ".unix.socket"); + + private void deleteSocketPath() { + try { + Files.deleteIfExists(socketPath); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected void onStart() { + super.onStart(); + deleteSocketPath(); + } + + @Override + protected void onClose() { + super.onClose(); + deleteSocketPath(); + } + + protected McpServerTransportProvider createMcpTransportProvider() { + return new UdsMcpServerTransportProviderImpl(UnixDomainSocketAddress.of(socketPath)); + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + +}