diff --git a/src/common/container.ts b/src/common/container.ts new file mode 100644 index 00000000..696835ae --- /dev/null +++ b/src/common/container.ts @@ -0,0 +1,35 @@ +import fs from "fs/promises"; + +let containerEnv: boolean | undefined; + +export async function detectContainerEnv(): Promise { + if (containerEnv !== undefined) { + return containerEnv; + } + + const detect = async function (): Promise { + if (process.platform !== "linux") { + return false; // we only support linux containers for now + } + + if (process.env.container) { + return true; + } + + const exists = await Promise.all( + ["/.dockerenv", "/run/.containerenv", "/var/run/.containerenv"].map(async (file) => { + try { + await fs.access(file); + return true; + } catch { + return false; + } + }) + ); + + return exists.includes(true); + }; + + containerEnv = await detect(); + return containerEnv; +} diff --git a/src/logger.ts b/src/logger.ts index 8157324b..354d8956 100644 --- a/src/logger.ts +++ b/src/logger.ts @@ -180,11 +180,16 @@ class CompositeLogger extends LoggerBase { const logger = new CompositeLogger(); export default logger; -export async function initializeLogger(server: McpServer, logPath: string): Promise { +export async function setStdioPreset(server: McpServer, logPath: string): Promise { const diskLogger = await DiskLogger.fromPath(logPath); const mcpLogger = new McpLogger(server); logger.setLoggers(mcpLogger, diskLogger); +} + +export function setContainerPreset(server: McpServer): void { + const mcpLogger = new McpLogger(server); + const consoleLogger = new ConsoleLogger(); - return logger; + logger.setLoggers(mcpLogger, consoleLogger); } diff --git a/src/server.ts b/src/server.ts index b0e8e19c..31a99ded 100644 --- a/src/server.ts +++ b/src/server.ts @@ -3,7 +3,7 @@ import { Session } from "./session.js"; import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; import { AtlasTools } from "./tools/atlas/tools.js"; import { MongoDbTools } from "./tools/mongodb/tools.js"; -import logger, { initializeLogger, LogId } from "./logger.js"; +import logger, { setStdioPreset, setContainerPreset, LogId } from "./logger.js"; import { ObjectId } from "mongodb"; import { Telemetry } from "./telemetry/telemetry.js"; import { UserConfig } from "./config.js"; @@ -11,6 +11,7 @@ import { type ServerEvent } from "./telemetry/types.js"; import { type ServerCommand } from "./telemetry/types.js"; import { CallToolRequestSchema, CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import assert from "assert"; +import { detectContainerEnv } from "./common/container.js"; export interface ServerOptions { session: Session; @@ -63,7 +64,13 @@ export class Server { return existingHandler(request, extra); }); - await initializeLogger(this.mcpServer, this.userConfig.logPath); + const containerEnv = await detectContainerEnv(); + + if (containerEnv) { + setContainerPreset(this.mcpServer); + } else { + await setStdioPreset(this.mcpServer, this.userConfig.logPath); + } await this.mcpServer.connect(transport); diff --git a/src/telemetry/telemetry.ts b/src/telemetry/telemetry.ts index 5d0ad827..f1e24e20 100644 --- a/src/telemetry/telemetry.ts +++ b/src/telemetry/telemetry.ts @@ -7,7 +7,7 @@ import { MACHINE_METADATA } from "./constants.js"; import { EventCache } from "./eventCache.js"; import nodeMachineId from "node-machine-id"; import { getDeviceId } from "@mongodb-js/device-id"; -import fs from "fs/promises"; +import { detectContainerEnv } from "../common/container.js"; type EventResult = { success: boolean; @@ -53,29 +53,6 @@ export class Telemetry { return instance; } - private async isContainerEnv(): Promise { - if (process.platform !== "linux") { - return false; // we only support linux containers for now - } - - if (process.env.container) { - return true; - } - - const exists = await Promise.all( - ["/.dockerenv", "/run/.containerenv", "/var/run/.containerenv"].map(async (file) => { - try { - await fs.access(file); - return true; - } catch { - return false; - } - }) - ); - - return exists.includes(true); - } - private async setup(): Promise { if (!this.isTelemetryEnabled()) { return; @@ -98,7 +75,7 @@ export class Telemetry { }, abortSignal: this.deviceIdAbortController.signal, }), - this.isContainerEnv(), + detectContainerEnv(), ]); const [deviceId, containerEnv] = await this.setupPromise; diff --git a/tests/integration/tools/atlas/clusters.test.ts b/tests/integration/tools/atlas/clusters.test.ts index 166ee637..8bb19bda 100644 --- a/tests/integration/tools/atlas/clusters.test.ts +++ b/tests/integration/tools/atlas/clusters.test.ts @@ -1,6 +1,7 @@ import { Session } from "../../../../src/session.js"; import { expectDefined } from "../../helpers.js"; import { describeWithAtlas, withProject, randomId } from "./atlasHelpers.js"; +import { ClusterDescription20240805 } from "../../../../src/common/atlas/openapi.js"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; function sleep(ms: number) { @@ -33,7 +34,12 @@ async function deleteAndWaitCluster(session: Session, projectId: string, cluster } } -async function waitClusterState(session: Session, projectId: string, clusterName: string, state: string) { +async function waitCluster( + session: Session, + projectId: string, + clusterName: string, + check: (cluster: ClusterDescription20240805) => boolean | Promise +) { while (true) { const cluster = await session.apiClient.getCluster({ params: { @@ -43,7 +49,7 @@ async function waitClusterState(session: Session, projectId: string, clusterName }, }, }); - if (cluster?.stateName === state) { + if (await check(cluster)) { return; } await sleep(1000); @@ -142,7 +148,12 @@ describeWithAtlas("clusters", (integration) => { describe("atlas-connect-cluster", () => { beforeAll(async () => { const projectId = getProjectId(); - await waitClusterState(integration.mcpServer().session, projectId, clusterName, "IDLE"); + await waitCluster(integration.mcpServer().session, projectId, clusterName, (cluster) => { + return ( + cluster.stateName === "IDLE" && + (cluster.connectionStrings?.standardSrv || cluster.connectionStrings?.standard) !== undefined + ); + }); await integration.mcpServer().session.apiClient.createProjectIpAccessList({ params: { path: {