diff --git a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java index cd1809966b..b1579cd119 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java @@ -103,13 +103,18 @@ abstract void authenticateAsync(InternalConnection connection, ConnectionDescrip OperationContext operationContext, SingleResultCallback callback); public void reauthenticate(final InternalConnection connection, final OperationContext operationContext) { - authenticate(connection, connection.getDescription(), operationContext); + authenticate(connection, connection.getDescription(), operationContextWithoutSession(operationContext)); } public void reauthenticateAsync(final InternalConnection connection, final OperationContext operationContext, final SingleResultCallback callback) { beginAsync().thenRun((c) -> { - authenticateAsync(connection, connection.getDescription(), operationContext, c); + authenticateAsync(connection, connection.getDescription(), operationContextWithoutSession(operationContext), c); }).finish(callback); } + + private static OperationContext operationContextWithoutSession(final OperationContext operationContext) { + return operationContext.withSessionContext( + new ReadConcernAwareNoOpSessionContext(operationContext.getSessionContext().getReadConcern())); + } } diff --git a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java index b6a23a576c..2b0544f0c5 100644 --- a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java +++ b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java @@ -24,9 +24,12 @@ import com.mongodb.MongoSecurityException; import com.mongodb.MongoSocketException; import com.mongodb.assertions.Assertions; +import com.mongodb.client.ClientSession; +import com.mongodb.client.FindIterable; import com.mongodb.client.Fixture; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClients; +import com.mongodb.client.MongoCollection; import com.mongodb.client.TestListener; import com.mongodb.event.CommandListener; import com.mongodb.lang.Nullable; @@ -334,12 +337,17 @@ public void test3p3UnexpectedErrorDoesNotClearCache() { @Test public void test4p1Reauthentication() { + testReauthentication(false); + } + + private void testReauthentication(final boolean inSession) { TestCallback callback = createCallback(); MongoClientSettings clientSettings = createSettings(callback); - try (MongoClient mongoClient = createMongoClient(clientSettings)) { + try (MongoClient mongoClient = createMongoClient(clientSettings); + ClientSession session = inSession ? mongoClient.startSession() : null) { failCommand(391, 1, "find"); // #. Perform a find operation that succeeds. - performFind(mongoClient); + performFind(mongoClient, session); } assertEquals(2, callback.invocations.get()); } @@ -392,6 +400,11 @@ private static void performInsert(final MongoClient mongoClient) { .insertOne(Document.parse("{ x: 1 }")); } + @Test + public void test4p5ReauthenticationInSession() { + testReauthentication(true); + } + @Test public void test5p1AzureSucceedsWithNoUsername() { assumeAzure(); @@ -914,12 +927,14 @@ private void assertFindFails( } } - private void performFind(final MongoClient mongoClient) { - mongoClient - .getDatabase("test") - .getCollection("test") - .find() - .first(); + private static void performFind(final MongoClient mongoClient) { + performFind(mongoClient, null); + } + + private static void performFind(final MongoClient mongoClient, @Nullable final ClientSession session) { + MongoCollection collection = mongoClient.getDatabase("test").getCollection("test"); + FindIterable findIterable = session == null ? collection.find() : collection.find(session); + findIterable.first(); } protected void delayNextFind() {