Skip to content

Commit

Permalink
Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zpear committed Nov 27, 2024
1 parent 9cafd8c commit 7266db5
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 9 deletions.
7 changes: 6 additions & 1 deletion src/java/org/apache/cassandra/config/DatabaseDescriptor.java
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,12 @@ public static int getRpcListenBacklog()

public static long getInternodeConnectionTimeout()
{
return conf.request_timeout_in_ms;
return conf.internode_connect_timeout_in_ms;
}

public static long setInternodeConnectionTimeout(long timeoutInMillis)
{
return conf.internode_connect_timeout_in_ms = timeoutInMillis;
}

public static long getRpcTimeout()
Expand Down
56 changes: 48 additions & 8 deletions src/java/org/apache/cassandra/net/OutboundTcpConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.IntStream;
import java.util.zip.Checksum;

import javax.net.ssl.SSLHandshakeException;
Expand All @@ -42,6 +43,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.palantir.logsafe.SafeArg;
import net.jpountz.lz4.LZ4BlockOutputStream;
import net.jpountz.lz4.LZ4Compressor;
import net.jpountz.lz4.LZ4Factory;
Expand All @@ -65,6 +67,7 @@
import org.apache.cassandra.config.Config;
import org.apache.cassandra.config.DatabaseDescriptor;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.Uninterruptibles;

Expand Down Expand Up @@ -216,8 +219,9 @@ public void run()
//The timestamp of the first message has already been provided to the coalescing strategy
//so skip logging it.
inner:
for (QueuedMessage qm : drainedMessages)
for (int i = 0; i < drainedMessages.size(); i++)
{
QueuedMessage qm = drainedMessages.get(i);
try
{
MessageOut<?> m = qm.message;
Expand All @@ -236,8 +240,10 @@ else if (socket != null || connect())
else
{
// clear out the queue, else gossip messages back up.
drainedMessages.clear();
backlog.clear();
int cleared = clearQueueWithFailureCallback(i, drainedMessages, drainedMessageSize, backlog);
logger.warn("Failed to connect to endpoint. Cleared backlog and invoked failure callbacks",
SafeArg.of("clearedMessages", cleared),
SafeArg.of("endpoint", poolReference.endPoint()));
currentMsgBufferCount = 0;
break inner;
}
Expand All @@ -255,6 +261,24 @@ else if (socket != null || connect())
}
}

@VisibleForTesting
int clearQueueWithFailureCallback(int currentMessage, List<QueuedMessage> bufferedMessages, int bufferSize, BlockingQueue<QueuedMessage> queue) {
bufferedMessages.stream().skip(currentMessage).forEach(this::invokeFailureCallback);
int initialCleared = bufferedMessages.size() - currentMessage;
bufferedMessages.clear();

int queueSize = queue.size();
int remaining = queueSize;
while (remaining > 0) {
remaining -= queue.drainTo(bufferedMessages, Math.min(bufferSize, remaining));
for (QueuedMessage qm : bufferedMessages) {
invokeFailureCallback(qm);
}
bufferedMessages.clear();
}
return initialCleared + queueSize;
}

public int getPendingMessages()
{
return backlog.size() + currentMsgBufferCount;
Expand Down Expand Up @@ -331,10 +355,7 @@ private void writeConnected(QueuedMessage qm, boolean flush)
throw new AssertionError(e1);
}
} else {
CallbackInfo registeredCallbackInfo = MessagingService.instance().getRegisteredCallback(qm.id);
if (registeredCallbackInfo != null && registeredCallbackInfo.isFailureCallback()) {
((IAsyncCallbackWithFailure) MessagingService.instance().removeRegisteredCallback(qm.id).callback).onFailure(poolReference.endPoint());
}
invokeFailureCallback(qm);
}
}
else
Expand All @@ -345,6 +366,25 @@ private void writeConnected(QueuedMessage qm, boolean flush)
}
}

@VisibleForTesting
void invokeFailureCallback(QueuedMessage qm) {
if (qm == null) {
return;
}
CallbackInfo registeredCallbackInfo = MessagingService.instance().getRegisteredCallback(qm.id);
if (registeredCallbackInfo != null && registeredCallbackInfo.isFailureCallback()) {
Optional.ofNullable(MessagingService.instance().removeRegisteredCallback(qm.id))
.map(info -> info.callback)
.map(callback -> (IAsyncCallbackWithFailure) callback)
.ifPresent(callback -> {
logger.debug("Invoking failure callback for message",
SafeArg.of("endpoint", poolReference.endPoint()),
SafeArg.of("messageId", qm.id), SafeArg.of("verb", qm.message.verb));
callback.onFailure(poolReference.endPoint());
});
}
}

private void writeInternal(MessageOut message, int id, long timestamp) throws IOException
{
out.writeInt(MessagingService.PROTOCOL_MAGIC);
Expand Down Expand Up @@ -594,7 +634,7 @@ private void expireMessages()
}

/** messages that have not been retried yet */
private static class QueuedMessage implements Coalescable
static class QueuedMessage implements Coalescable
{
final MessageOut<?> message;
final int id;
Expand Down
174 changes: 174 additions & 0 deletions test/unit/org/apache/cassandra/net/OutboundTcpConnectionTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.cassandra.net;

import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;

import org.apache.cassandra.io.IVersionedSerializer;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.anyCollection;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

public class OutboundTcpConnectionTest
{
private OutboundTcpConnection connection;
private static OutboundTcpConnectionPool pool;
private static final InetAddress TARGET = mock(InetAddress.class);
private static final OutboundTcpConnection.QueuedMessage QM1 = new OutboundTcpConnection.QueuedMessage(new MessageOut<>(MessagingService.Verb.MUTATION), 1);
private static final OutboundTcpConnection.QueuedMessage QM2 = new OutboundTcpConnection.QueuedMessage(new MessageOut<>(MessagingService.Verb.MUTATION), 2);
private static final OutboundTcpConnection.QueuedMessage QM3 = new OutboundTcpConnection.QueuedMessage(new MessageOut<>(MessagingService.Verb.MUTATION), 3);

@BeforeClass
public static void beforeClass() throws UnknownHostException
{
pool = mock(OutboundTcpConnectionPool.class);
doReturn(InetAddress.getLocalHost()).when(pool).endPoint();
}

@Before
public void before() {
connection = spy(new OutboundTcpConnection(pool, "test"));
MessagingService.instance().clearCallbacksUnsafe();
}

@Test
public void invokeFailureCallback_ignoresNonFailureCallbacks() {
TestCallback cb = new TestCallback();
CallbackInfo nonFailureCallback = new CallbackInfo(TARGET, cb, mock(IVersionedSerializer.class), false);
MessagingService.instance().setCallbackForTests(QM1.id, nonFailureCallback);
connection.invokeFailureCallback(QM1);
assertEquals(0, cb.responses.get());
}

@Test
public void invokeFailureCallback_handlesExpiredCallback() {
assertNull(MessagingService.instance().getRegisteredCallback(QM1.id));
connection.invokeFailureCallback(QM1);
}

@Test
public void invokeFailureCallback_runsCallback() {
TestFailureCallback cb = registerFailureCallback(QM1);
connection.invokeFailureCallback(QM1);
assertEquals(0, cb.responses.get());
assertEquals(1, cb.failures.get());
assertNull(MessagingService.instance().getRegisteredCallback(QM1.id));
}

@Test
public void clearQueueWithFailureCallback_handlesInProgressDrainedList() throws InterruptedException
{
List<OutboundTcpConnection.QueuedMessage> drained = new ArrayList<>(2);
drained.add(QM1);
drained.add(QM2);
BlockingQueue<OutboundTcpConnection.QueuedMessage> backlog = new LinkedBlockingQueue<>();
backlog.put(QM3);

TestFailureCallback cb1 = registerFailureCallback(QM1);
TestFailureCallback cb2 = registerFailureCallback(QM2);
TestFailureCallback cb3 = registerFailureCallback(QM3);

connection.clearQueueWithFailureCallback(1, drained, 2, backlog);

assertEquals(0, cb1.failures.get());
assertEquals(1, cb2.failures.get());
assertEquals(1, cb3.failures.get());

assertTrue(drained.isEmpty());
assertTrue(backlog.isEmpty());
}

@Test
public void clearQueueWithFailureCallback_clearsLargeBacklog() throws InterruptedException
{
List<OutboundTcpConnection.QueuedMessage> drained = new ArrayList<>(2);
BlockingQueue<OutboundTcpConnection.QueuedMessage> backlog = spy(new LinkedBlockingQueue<>());
backlog.put(QM1);
backlog.put(QM2);
backlog.put(QM3);
backlog.put(QM3);
backlog.put(QM3);

TestFailureCallback cb1 = registerFailureCallback(QM1);
TestFailureCallback cb2 = registerFailureCallback(QM2);
TestFailureCallback cb3 = registerFailureCallback(QM3);

connection.clearQueueWithFailureCallback(0, drained, 2, backlog);
// With enough elements remaining, drain the buffer size
verify(backlog, times(2)).drainTo(anyCollection(), eq(2));
// Last call, don't take more off the backlog than needed from when we first called clearQueueWithFailureCallback
verify(backlog, times(1)).drainTo(anyCollection(), eq(1));

assertEquals(1, cb1.failures.get());
assertEquals(1, cb2.failures.get());
assertEquals(1, cb3.failures.get());

assertTrue(drained.isEmpty());
assertTrue(backlog.isEmpty());
}

static class TestCallback<T> implements IAsyncCallback<T>
{
public final AtomicInteger responses = new AtomicInteger(0);

public void response(MessageIn<T> _msg)
{
responses.incrementAndGet();

}

public boolean isLatencyForSnitch()
{
return false;
}
}

static class TestFailureCallback<T> extends TestCallback<T> implements IAsyncCallbackWithFailure<T> {
public final AtomicInteger failures = new AtomicInteger(0);

public void onFailure(InetAddress from)
{
failures.incrementAndGet();
}
}

private TestFailureCallback registerFailureCallback(OutboundTcpConnection.QueuedMessage qm) {
TestFailureCallback cb = new TestFailureCallback();
MessagingService.instance().setCallbackForTests(qm.id, new CallbackInfo(TARGET, cb, mock(IVersionedSerializer.class), true));
return cb;
}
}

0 comments on commit 7266db5

Please sign in to comment.