diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java index 6fca357b08..79c21f3335 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java @@ -40,7 +40,6 @@ import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO; import static com.mongodb.internal.connection.CommandHelper.executeCommand; import static com.mongodb.internal.connection.CommandHelper.executeCommandAsync; -import static com.mongodb.internal.connection.CommandHelper.executeCommandWithoutCheckingForFailure; import static com.mongodb.internal.connection.DefaultAuthenticator.USER_NOT_FOUND_CODE; import static com.mongodb.internal.connection.DescriptionHelper.createConnectionDescription; import static com.mongodb.internal.connection.DescriptionHelper.createServerDescription; @@ -88,7 +87,8 @@ public InternalConnectionInitializationDescription finishHandshake(final Interna if (Authenticator.shouldAuthenticate(authenticator, connectionDescription)) { authenticator.authenticate(internalConnection, connectionDescription, operationContext); } - return completeConnectionDescriptionInitialization(internalConnection, description, operationContext); + + return description; } @Override @@ -121,14 +121,14 @@ public void finishHandshakeAsync(final InternalConnection internalConnection, ConnectionDescription connectionDescription = description.getConnectionDescription(); if (!Authenticator.shouldAuthenticate(authenticator, connectionDescription)) { - completeConnectionDescriptionInitializationAsync(internalConnection, description, operationContext, callback); + callback.onResult(description, null); } else { authenticator.authenticateAsync(internalConnection, connectionDescription, operationContext, (result1, t1) -> { if (t1 != null) { callback.onResult(null, t1); } else { - completeConnectionDescriptionInitializationAsync(internalConnection, description, operationContext, callback); + callback.onResult(description, null); } }); } @@ -203,21 +203,6 @@ private BsonDocument createHelloCommand(final Authenticator authenticator, final return helloCommandDocument; } - private InternalConnectionInitializationDescription completeConnectionDescriptionInitialization( - final InternalConnection internalConnection, - final InternalConnectionInitializationDescription description, - final OperationContext operationContext) { - - if (description.getConnectionDescription().getConnectionId().getServerValue() != null) { - return description; - } - - return applyGetLastErrorResult(executeCommandWithoutCheckingForFailure("admin", - new BsonDocument("getlasterror", new BsonInt32(1)), clusterConnectionMode, serverApi, - internalConnection, operationContext), - description); - } - private void setSpeculativeAuthenticateResponse(final BsonDocument helloResult) { if (authenticator instanceof SpeculativeAuthenticator) { ((SpeculativeAuthenticator) authenticator).setSpeculativeAuthenticateResponse( @@ -225,28 +210,6 @@ private void setSpeculativeAuthenticateResponse(final BsonDocument helloResult) } } - private void completeConnectionDescriptionInitializationAsync( - final InternalConnection internalConnection, - final InternalConnectionInitializationDescription description, - final OperationContext operationContext, - final SingleResultCallback callback) { - - if (description.getConnectionDescription().getConnectionId().getServerValue() != null) { - callback.onResult(description, null); - return; - } - - executeCommandAsync("admin", new BsonDocument("getlasterror", new BsonInt32(1)), clusterConnectionMode, serverApi, - internalConnection, operationContext, - (result, t) -> { - if (t != null) { - callback.onResult(description, null); - } else { - callback.onResult(applyGetLastErrorResult(result, description), null); - } - }); - } - private InternalConnectionInitializationDescription applyGetLastErrorResult( final BsonDocument getLastErrorResult, final InternalConnectionInitializationDescription description) { diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/InternalStreamConnectionInitializerSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/InternalStreamConnectionInitializerSpecification.groovy index 93bc656226..156499797c 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/InternalStreamConnectionInitializerSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/InternalStreamConnectionInitializerSpecification.groovy @@ -61,14 +61,14 @@ class InternalStreamConnectionInitializerSpecification extends Specification { def initializer = new InternalStreamConnectionInitializer(SINGLE, null, null, [], null) when: - enqueueSuccessfulReplies(false, null) + enqueueSuccessfulReplies(false, 123) def description = initializer.startHandshake(internalConnection, operationContext) description = initializer.finishHandshake(internalConnection, description, operationContext) def connectionDescription = description.connectionDescription def serverDescription = description.serverDescription then: - connectionDescription == getExpectedConnectionDescription(connectionDescription.connectionId.localValue, null) + connectionDescription == getExpectedConnectionDescription(connectionDescription.connectionId.localValue, 123) serverDescription == getExpectedServerDescription(serverDescription) } @@ -77,7 +77,7 @@ class InternalStreamConnectionInitializerSpecification extends Specification { def initializer = new InternalStreamConnectionInitializer(SINGLE, null, null, [], null) when: - enqueueSuccessfulReplies(false, null) + enqueueSuccessfulReplies(false, 123) def futureCallback = new FutureResultCallback() initializer.startHandshakeAsync(internalConnection, operationContext, futureCallback) def description = futureCallback.get() @@ -88,7 +88,7 @@ class InternalStreamConnectionInitializerSpecification extends Specification { def serverDescription = description.serverDescription then: - connectionDescription == getExpectedConnectionDescription(connectionDescription.connectionId.localValue, null) + connectionDescription == getExpectedConnectionDescription(connectionDescription.connectionId.localValue, 123) serverDescription == getExpectedServerDescription(serverDescription) } @@ -106,20 +106,6 @@ class InternalStreamConnectionInitializerSpecification extends Specification { connectionDescription == getExpectedConnectionDescription(connectionDescription.connectionId.localValue, 123) } - def 'should create correct description with server connection id from hello'() { - given: - def initializer = new InternalStreamConnectionInitializer(SINGLE, null, null, [], null) - - when: - enqueueSuccessfulRepliesWithConnectionIdIsHelloResponse(false, 123) - def internalDescription = initializer.startHandshake(internalConnection, operationContext) - def connectionDescription = initializer.finishHandshake(internalConnection, internalDescription, operationContext) - .connectionDescription - - then: - connectionDescription == getExpectedConnectionDescription(connectionDescription.connectionId.localValue, 123) - } - def 'should create correct description with server connection id asynchronously'() { given: def initializer = new InternalStreamConnectionInitializer(SINGLE, null, null, [], null) @@ -137,31 +123,13 @@ class InternalStreamConnectionInitializerSpecification extends Specification { connectionDescription == getExpectedConnectionDescription(connectionDescription.connectionId.localValue, 123) } - def 'should create correct description with server connection id from hello asynchronously'() { - given: - def initializer = new InternalStreamConnectionInitializer(SINGLE, null, null, [], null) - - when: - enqueueSuccessfulRepliesWithConnectionIdIsHelloResponse(false, 123) - def futureCallback = new FutureResultCallback() - initializer.startHandshakeAsync(internalConnection, operationContext, futureCallback) - def description = futureCallback.get() - futureCallback = new FutureResultCallback() - initializer.finishHandshakeAsync(internalConnection, description, operationContext, futureCallback) - description = futureCallback.get() - def connectionDescription = description.connectionDescription - - then: - connectionDescription == getExpectedConnectionDescription(connectionDescription.connectionId.localValue, 123) - } - def 'should authenticate'() { given: def firstAuthenticator = Mock(Authenticator) def initializer = new InternalStreamConnectionInitializer(SINGLE, firstAuthenticator, null, [], null) when: - enqueueSuccessfulReplies(false, null) + enqueueSuccessfulReplies(false, 123) def internalDescription = initializer.startHandshake(internalConnection, operationContext) def connectionDescription = initializer.finishHandshake(internalConnection, internalDescription, operationContext) @@ -178,7 +146,7 @@ class InternalStreamConnectionInitializerSpecification extends Specification { def initializer = new InternalStreamConnectionInitializer(SINGLE, authenticator, null, [], null) when: - enqueueSuccessfulReplies(false, null) + enqueueSuccessfulReplies(false, 123) def futureCallback = new FutureResultCallback() initializer.startHandshakeAsync(internalConnection, operationContext, futureCallback) @@ -198,7 +166,7 @@ class InternalStreamConnectionInitializerSpecification extends Specification { def initializer = new InternalStreamConnectionInitializer(SINGLE, authenticator, null, [], null) when: - enqueueSuccessfulReplies(true, null) + enqueueSuccessfulReplies(true, 123) def internalDescription = initializer.startHandshake(internalConnection, operationContext) def connectionDescription = initializer.finishHandshake(internalConnection, internalDescription, operationContext) @@ -215,7 +183,7 @@ class InternalStreamConnectionInitializerSpecification extends Specification { def initializer = new InternalStreamConnectionInitializer(SINGLE, authenticator, null, [], null) when: - enqueueSuccessfulReplies(true, null) + enqueueSuccessfulReplies(true, 123) def futureCallback = new FutureResultCallback() initializer.startHandshakeAsync(internalConnection, operationContext, futureCallback) @@ -240,7 +208,7 @@ class InternalStreamConnectionInitializerSpecification extends Specification { } when: - enqueueSuccessfulReplies(false, null) + enqueueSuccessfulReplies(false, 123) if (async) { def callback = new FutureResultCallback() initializer.startHandshakeAsync(internalConnection, operationContext, callback) @@ -277,7 +245,7 @@ class InternalStreamConnectionInitializerSpecification extends Specification { } when: - enqueueSuccessfulReplies(false, null) + enqueueSuccessfulReplies(false, 123) if (async) { def callback = new FutureResultCallback() initializer.startHandshakeAsync(internalConnection, operationContext, callback) @@ -477,25 +445,12 @@ class InternalStreamConnectionInitializerSpecification extends Specification { } def enqueueSuccessfulReplies(final boolean isArbiter, final Integer serverConnectionId) { - internalConnection.enqueueReply(buildSuccessfulReply( - '{ok: 1, ' + - 'maxWireVersion: 3' + - (isArbiter ? ', isreplicaset: true, arbiterOnly: true' : '') + - '}')) - internalConnection.enqueueReply(buildSuccessfulReply( - '{ok: 1 ' + - (serverConnectionId == null ? '' : ', connectionId: ' + serverConnectionId) + - '}')) - } - - def enqueueSuccessfulRepliesWithConnectionIdIsHelloResponse(final boolean isArbiter, final Integer serverConnectionId) { internalConnection.enqueueReply(buildSuccessfulReply( '{ok: 1, ' + 'maxWireVersion: 3,' + 'connectionId: ' + serverConnectionId + (isArbiter ? ', isreplicaset: true, arbiterOnly: true' : '') + '}')) - internalConnection.enqueueReply(buildSuccessfulReply('{ok: 1, versionArray : [3, 0, 0]}')) } def enqueueSpeculativeAuthenticationResponsesForScramSha256() {