Skip to content

Commit 24233a8

Browse files
committed
feat: add OwnershipSynchronizer to abstract consumer migration
1 parent 82ca33f commit 24233a8

File tree

1 file changed

+164
-47
lines changed

1 file changed

+164
-47
lines changed

src/main/java/io/lettuce/core/protocol/DefaultAutoBatchFlushEndpoint.java

+164-47
Original file line numberDiff line numberDiff line change
@@ -144,25 +144,25 @@ protected static void cancelCommandOnEndpointClose(RedisCommand<?, ?, ?> cmd) {
144144

145145
private final boolean debugEnabled = logger.isDebugEnabled();
146146

147-
protected final CompletableFuture<Void> closeFuture = new CompletableFuture<>();
147+
private final CompletableFuture<Void> closeFuture = new CompletableFuture<>();
148148

149149
private String logPrefix;
150150

151151
private boolean autoFlushCommands = true;
152152

153153
private boolean inActivation = false;
154154

155-
protected @Nullable ConnectionWatchdog connectionWatchdog;
155+
private @Nullable ConnectionWatchdog connectionWatchdog;
156156

157157
private ConnectionFacade connectionFacade;
158158

159159
private final String cachedEndpointId;
160160

161-
protected final UnboundedOfferFirstQueue<Object> taskQueue;
161+
private final UnboundedOfferFirstQueue<Object> taskQueue;
162162

163-
private final boolean canFire;
163+
private final OwnershipSynchronizer taskQueueConsumeSync; // make sure only one consumer exists at any given time
164164

165-
private volatile EventExecutor lastEventExecutor;
165+
private final boolean canFire;
166166

167167
private volatile Throwable connectionError;
168168

@@ -172,8 +172,6 @@ protected static void cancelCommandOnEndpointClose(RedisCommand<?, ?, ?> cmd) {
172172

173173
private final int batchSize;
174174

175-
private final boolean usesMpscQueue;
176-
177175
/**
178176
* Create a new {@link AutoBatchFlushEndpoint}.
179177
*
@@ -197,13 +195,14 @@ protected DefaultAutoBatchFlushEndpoint(ClientOptions clientOptions, ClientResou
197195
this.rejectCommandsWhileDisconnected = isRejectCommand(clientOptions);
198196
long endpointId = ENDPOINT_COUNTER.incrementAndGet();
199197
this.cachedEndpointId = "0x" + Long.toHexString(endpointId);
200-
this.usesMpscQueue = clientOptions.getAutoBatchFlushOptions().usesMpscQueue();
201-
this.taskQueue = usesMpscQueue ? new JcToolsUnboundedMpscOfferFirstQueue<>() : new ConcurrentLinkedOfferFirstQueue<>();
198+
this.taskQueue = clientOptions.getAutoBatchFlushOptions().usesMpscQueue() ? new JcToolsUnboundedMpscOfferFirstQueue<>()
199+
: new ConcurrentLinkedOfferFirstQueue<>();
202200
this.canFire = false;
203201
this.callbackOnClose = callbackOnClose;
204202
this.writeSpinCount = clientOptions.getAutoBatchFlushOptions().getWriteSpinCount();
205203
this.batchSize = clientOptions.getAutoBatchFlushOptions().getBatchSize();
206-
this.lastEventExecutor = clientResources.eventExecutorGroup().next();
204+
this.taskQueueConsumeSync = new OwnershipSynchronizer(clientResources.eventExecutorGroup().next(),
205+
Thread.currentThread().getName(), true/* allows to be preempted by first event loop thread */);
207206
}
208207

209208
@Override
@@ -324,7 +323,8 @@ public void notifyChannelActive(Channel channel) {
324323
return;
325324
}
326325

327-
this.lastEventExecutor = channel.eventLoop();
326+
this.taskQueueConsumeSync.preempt(channel.eventLoop(), Thread.currentThread().getName(),
327+
false /* disallow preempt until reached quiescent point, see onEndpointQuiescence() */);
328328
this.connectionError = null;
329329
this.inProtectMode = false;
330330
this.logPrefix = null;
@@ -379,7 +379,7 @@ public void notifyReconnectFailed(Throwable t) {
379379
return;
380380
}
381381

382-
syncAfterTerminated(() -> {
382+
taskQueueConsumeSync.execute(() -> {
383383
if (isClosed()) {
384384
onEndpointClosed();
385385
} else {
@@ -474,10 +474,10 @@ public void flushCommands() {
474474
final ContextualChannel chan = this.channel;
475475
switch (chan.context.initialState) {
476476
case ENDPOINT_CLOSED:
477-
syncAfterTerminated(this::onEndpointClosed);
477+
taskQueueConsumeSync.execute(this::onEndpointClosed);
478478
return;
479479
case RECONNECT_FAILED:
480-
syncAfterTerminated(() -> {
480+
taskQueueConsumeSync.execute(() -> {
481481
if (isClosed()) {
482482
onEndpointClosed();
483483
} else {
@@ -563,7 +563,6 @@ public void disconnect() {
563563
*/
564564
@Override
565565
public void reset() {
566-
567566
if (debugEnabled) {
568567
logger.debug("{} reset()", logPrefix());
569568
}
@@ -572,10 +571,7 @@ public void reset() {
572571
if (chan.context.initialState.isConnected()) {
573572
chan.pipeline().fireUserEventTriggered(new ConnectionEvents.Reset());
574573
}
575-
if (!usesMpscQueue) {
576-
cancelCommands("reset");
577-
}
578-
// Otherwise, unsafe to call cancelBufferedCommands() here.
574+
taskQueueConsumeSync.execute(() -> cancelCommands("reset"));
579575
}
580576

581577
private void resetInternal() {
@@ -587,7 +583,6 @@ private void resetInternal() {
587583
if (chan.context.initialState.isConnected()) {
588584
chan.pipeline().fireUserEventTriggered(new ConnectionEvents.Reset());
589585
}
590-
LettuceAssert.assertState(lastEventExecutor.inEventLoop(), "must be called in lastEventLoop thread");
591586
cancelCommands("resetInternal");
592587
}
593588

@@ -596,10 +591,8 @@ private void resetInternal() {
596591
*/
597592
@Override
598593
public void initialState() {
599-
if (!usesMpscQueue) {
600-
cancelCommands("initialState");
601-
}
602-
// Otherwise, unsafe to call cancelBufferedCommands() here.
594+
taskQueueConsumeSync.execute(() -> cancelCommands("initialState"));
595+
603596
ContextualChannel currentChannel = this.channel;
604597
if (currentChannel.context.initialState.isConnected()) {
605598
ChannelFuture close = currentChannel.close();
@@ -637,8 +630,6 @@ public String getId() {
637630
}
638631

639632
private void scheduleSendJobOnConnected(final ContextualChannel chan) {
640-
LettuceAssert.assertState(chan.eventLoop().inEventLoop(), "must be called in event loop thread");
641-
642633
// Schedule directly
643634
loopSend(chan, false);
644635
}
@@ -758,7 +749,6 @@ private int pollBatch(final AutoBatchFlushEndPointContext autoBatchFlushEndPoint
758749
private void trySetEndpointQuiescence(ContextualChannel chan) {
759750
final EventLoop eventLoop = chan.eventLoop();
760751
LettuceAssert.isTrue(eventLoop.inEventLoop(), "unexpected: not in event loop");
761-
LettuceAssert.isTrue(eventLoop == lastEventExecutor, "unexpected: lastEventLoop not match");
762752

763753
final ConnectionContext connectionContext = chan.context;
764754
final @Nullable ConnectionContext.CloseStatus closeStatus = connectionContext.getCloseStatus();
@@ -827,6 +817,8 @@ private void onWontReconnect(@Nonnull final ConnectionContext.CloseStatus closeS
827817
}
828818

829819
private void onEndpointQuiescence() {
820+
taskQueueConsumeSync.done(1); // allows preemption
821+
830822
if (channel.context.initialState == ConnectionContext.State.ENDPOINT_CLOSED) {
831823
return;
832824
}
@@ -864,7 +856,7 @@ private final void onEndpointClosed(Queue<RedisCommand<?, ?, ?>>... queues) {
864856
fulfillCommands("endpoint closed", callbackOnClose, queues);
865857
}
866858

867-
private final void onReconnectFailed() {
859+
private void onReconnectFailed() {
868860
fulfillCommands("reconnect failed", cmd -> cmd.completeExceptionally(getFailedToReconnectReason()));
869861
}
870862

@@ -996,7 +988,7 @@ private Throwable validateWrite(ContextualChannel chan, int commands, boolean is
996988
private void onUnexpectedState(String caller, ConnectionContext.State exp) {
997989
final ConnectionContext.State actual = this.channel.context.initialState;
998990
logger.error("{}[{}][unexpected] : unexpected state: exp '{}' got '{}'", logPrefix(), caller, exp, actual);
999-
syncAfterTerminated(
991+
taskQueueConsumeSync.execute(
1000992
() -> cancelCommands(String.format("%s: state not match: expect '%s', got '%s'", caller, exp, actual)));
1001993
}
1002994

@@ -1017,23 +1009,6 @@ private ChannelFuture channelWrite(Channel channel, RedisCommand<?, ?, ?> comman
10171009
return channel.write(command);
10181010
}
10191011

1020-
/*
1021-
* Synchronize after the endpoint is terminated. This is to ensure only one thread can access the task queue after endpoint
1022-
* is terminated (state is RECONNECT_FAILED/ENDPOINT_CLOSED)
1023-
*/
1024-
private void syncAfterTerminated(Runnable runnable) {
1025-
final EventExecutor localLastEventExecutor = lastEventExecutor;
1026-
if (localLastEventExecutor.inEventLoop()) {
1027-
runnable.run();
1028-
} else {
1029-
localLastEventExecutor.execute(() -> {
1030-
runnable.run();
1031-
LettuceAssert.isTrue(lastEventExecutor == localLastEventExecutor,
1032-
"lastEventLoop must not be changed after terminated");
1033-
});
1034-
}
1035-
}
1036-
10371012
private enum Reliability {
10381013
AT_MOST_ONCE, AT_LEAST_ONCE
10391014
}
@@ -1103,7 +1078,7 @@ public void operationComplete(Future<Void> future) {
11031078

11041079
final Throwable retryableErr = checkSendResult(future);
11051080
if (retryableErr != null && autoBatchFlushEndPointContext.addRetryableFailedToSendCommand(cmd, retryableErr)) {
1106-
// Close connection on first transient write failure
1081+
// Close connection on first transient write failure.
11071082
internalCloseConnectionIfNeeded(retryableErr);
11081083
}
11091084

@@ -1163,6 +1138,7 @@ private void internalCloseConnectionIfNeeded(Throwable reason) {
11631138
return;
11641139
}
11651140

1141+
// It is really rare (maybe impossible?) that the connection is still active.
11661142
logger.error(
11671143
"[internalCloseConnectionIfNeeded][interesting][{}] close the connection due to write error, reason: '{}'",
11681144
endpoint.logPrefix(), reason.getMessage(), reason);
@@ -1184,4 +1160,145 @@ private void recycle() {
11841160

11851161
}
11861162

1163+
public static class OwnershipSynchronizer {
1164+
1165+
private static class Owner {
1166+
1167+
private final EventExecutor thread;
1168+
1169+
private final String threadName;
1170+
1171+
// if positive, no other thread can preempt the ownership.
1172+
private final int semaphore;
1173+
1174+
public Owner(EventExecutor thread, String threadName, int semaphore) {
1175+
LettuceAssert.assertState(semaphore >= 0, () -> String.format("negative semaphore: %d", semaphore));
1176+
this.thread = thread;
1177+
this.threadName = threadName;
1178+
this.semaphore = semaphore;
1179+
}
1180+
1181+
public boolean isCurrentThread() {
1182+
return thread.inEventLoop();
1183+
}
1184+
1185+
public Owner toAdd(int n) {
1186+
return new Owner(thread, threadName, semaphore + n);
1187+
}
1188+
1189+
public Owner toDone(int n) {
1190+
return new Owner(thread, threadName, semaphore - n);
1191+
}
1192+
1193+
public boolean isDone() {
1194+
return semaphore == 0;
1195+
}
1196+
1197+
}
1198+
1199+
private static final AtomicReferenceFieldUpdater<OwnershipSynchronizer, Owner> OWNER = AtomicReferenceFieldUpdater
1200+
.newUpdater(OwnershipSynchronizer.class, Owner.class, "owner");
1201+
1202+
private volatile Owner owner;
1203+
1204+
public OwnershipSynchronizer(EventExecutor thread, String threadName, boolean allowsPreemptByOtherThreads) {
1205+
this.owner = new Owner(thread, threadName, allowsPreemptByOtherThreads ? 0 : 1);
1206+
}
1207+
1208+
/**
1209+
* Preempt ownership only when there is no running tasks in current owner
1210+
*
1211+
* @param thread new thread
1212+
* @param threadName thread name
1213+
* @param allowsPreemptByOtherThreads whether allows a third thread to preempt after @param `thread` preempts from
1214+
* current owner thread, if true, initial running task number will be set to 1.
1215+
*/
1216+
public void preempt(EventExecutor thread, String threadName, boolean allowsPreemptByOtherThreads) {
1217+
Owner cur;
1218+
Owner newOwner = null;
1219+
while (true) {
1220+
cur = this.owner;
1221+
if (cur.thread == thread) {
1222+
if (allowsPreemptByOtherThreads) {
1223+
return;
1224+
}
1225+
if (OWNER.compareAndSet(this, cur, cur.toAdd(1))) { // prevent preempt
1226+
return;
1227+
}
1228+
continue;
1229+
}
1230+
1231+
if (!cur.isDone()) {
1232+
// unsafe to preempt
1233+
continue;
1234+
}
1235+
1236+
if (newOwner == null) {
1237+
newOwner = new Owner(thread, threadName, allowsPreemptByOtherThreads ? 0 : 1);
1238+
}
1239+
if (OWNER.compareAndSet(this, cur, newOwner)) {
1240+
logger.debug("ownership preempted by a new thread [{}]", threadName);
1241+
// established happens-before with done()
1242+
return;
1243+
}
1244+
}
1245+
}
1246+
1247+
/**
1248+
* done n tasks in current owner.
1249+
*
1250+
* @param n number of tasks to be done.
1251+
*/
1252+
public void done(int n) {
1253+
Owner cur;
1254+
do {
1255+
cur = this.owner;
1256+
assertIsOwnerThreadAndPreemptPrevented(cur);
1257+
} while (!OWNER.compareAndSet(this, cur, cur.toDone(n)));
1258+
// create happens-before with preempt()
1259+
}
1260+
1261+
/**
1262+
* Safely run a task in current owner thread and release its memory effect to next owner thread.
1263+
*
1264+
* @param task task to run
1265+
*/
1266+
public void execute(Runnable task) {
1267+
Owner cur;
1268+
do {
1269+
cur = this.owner;
1270+
if (isOwnerCurrentThreadAndPreemptPrevented(cur)) {
1271+
// already prevented preemption, safe to skip expensive add/done calls
1272+
task.run();
1273+
return;
1274+
}
1275+
} while (!OWNER.compareAndSet(this, cur, cur.toAdd(1)));
1276+
1277+
if (cur.isCurrentThread()) {
1278+
executeInOwnerWithPreemptPrevention(task);
1279+
} else {
1280+
cur.thread.execute(() -> executeInOwnerWithPreemptPrevention(task));
1281+
}
1282+
}
1283+
1284+
private void executeInOwnerWithPreemptPrevention(Runnable task) {
1285+
try {
1286+
task.run();
1287+
} finally {
1288+
done(1);
1289+
}
1290+
}
1291+
1292+
private void assertIsOwnerThreadAndPreemptPrevented(Owner cur) {
1293+
LettuceAssert.assertState(isOwnerCurrentThreadAndPreemptPrevented(cur),
1294+
() -> "[executeInOwnerWithPreemptPrevention] unexpected: "
1295+
+ (cur.isCurrentThread() ? "preemption not prevented" : "owner is not this thread"));
1296+
}
1297+
1298+
private boolean isOwnerCurrentThreadAndPreemptPrevented(Owner owner) {
1299+
return owner.isCurrentThread() && !owner.isDone();
1300+
}
1301+
1302+
}
1303+
11871304
}

0 commit comments

Comments
 (0)