From 3047289f78274cfc49448168119d8f2fc84bc4bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=99=8E=E9=B8=A3?= Date: Mon, 20 Jan 2025 00:33:53 +0800 Subject: [PATCH] chore: dedicated gears scheduler --- build.sbt | 6 +- jvm/src/main/scala/async/VThreadSupport.scala | 101 ++++++++++++++++-- 2 files changed, 96 insertions(+), 11 deletions(-) diff --git a/build.sbt b/build.sbt index ddfa5a0f..bbb3c620 100644 --- a/build.sbt +++ b/build.sbt @@ -31,7 +31,11 @@ lazy val root = ) .jvmSettings( Seq( - javaOptions += "--version 21" + javaOptions += "--version 21", + Test / javaOptions ++= Seq( + "--add-opens=java.base/java.lang=ALL-UNNAMED", + "--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED" + ) ) ) .nativeSettings( diff --git a/jvm/src/main/scala/async/VThreadSupport.scala b/jvm/src/main/scala/async/VThreadSupport.scala index 8b2dc371..4b01b37f 100644 --- a/jvm/src/main/scala/async/VThreadSupport.scala +++ b/jvm/src/main/scala/async/VThreadSupport.scala @@ -1,15 +1,98 @@ package gears.async -import java.lang.invoke.{MethodHandles, VarHandle} +import java.lang.invoke.{MethodHandles, MethodType} +import java.util.concurrent.TimeUnit.SECONDS +import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.locks.ReentrantLock +import java.util.concurrent.{Executor, ForkJoinPool, ForkJoinWorkerThread, ThreadFactory} import scala.annotation.unchecked.uncheckedVariance import scala.concurrent.duration.FiniteDuration +import scala.util.control.NonFatal object VThreadScheduler extends Scheduler: - private val VTFactory = Thread - .ofVirtual() - .name("gears.async.VThread-", 0L) - .factory() + private val LOOKUP: MethodHandles.Lookup = MethodHandles.lookup() + + private object CarrierThreadFactory extends ForkJoinPool.ForkJoinWorkerThreadFactory { + private val clazz = LOOKUP.findClass("jdk.internal.misc.CarrierThread") + private val constructor = LOOKUP.findConstructor(clazz, MethodType.methodType(classOf[Unit], classOf[ForkJoinPool])) + private val counter = new AtomicLong(0L) + + override def newThread(pool: ForkJoinPool): ForkJoinWorkerThread = { + val t = constructor.invoke(pool).asInstanceOf[ForkJoinWorkerThread] + t.setName("gears-CarrierThread-" + counter.getAndIncrement()) + t + } + } + + private val DEFAULT_SCHEDULER: ForkJoinPool = { + val parallelismValue = sys.props + .get("gears.default-scheduler.parallelism") + .map(_.toInt) + .getOrElse(Runtime.getRuntime.availableProcessors()) + + val maxPoolSizeValue = sys.props + .get("gears.default-scheduler.max-pool-size") + .map(_.toInt) + .getOrElse(256) + + val minRunnableValue = sys.props + .get("gears.default-scheduler.min-runnable") + .map(_.toInt) + .getOrElse(parallelismValue / 2) + + new ForkJoinPool( + parallelismValue, + CarrierThreadFactory, + (t: Thread, e: Throwable) => { + // noop for now + }, + true, + 0, + maxPoolSizeValue, + minRunnableValue, + (pool: ForkJoinPool) => true, + 60, + SECONDS + ) + } + + private val VTFactory = createVirtualThreadFactory("gears", DEFAULT_SCHEDULER) + + /** Create a virtual thread factory with an executor, the executor will be used as the scheduler of virtual thread. + * + * The executor should run task on platform threads. + * + * returns null if not supported. + */ + private def createVirtualThreadFactory(prefix: String, executor: Executor): ThreadFactory = + try { + val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder") + val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual") + val ofVirtualMethod = LOOKUP.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass)) + var builder = ofVirtualMethod.invoke() + if (executor != null) { + val clazz = builder.getClass + val privateLookup = MethodHandles.privateLookupIn( + clazz, + LOOKUP + ) + val schedulerFieldSetter = privateLookup + .findSetter(clazz, "scheduler", classOf[Executor]) + schedulerFieldSetter.invoke(builder, executor) + } + val nameMethod = LOOKUP.findVirtual( + ofVirtualClass, + "name", + MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long]) + ) + val factoryMethod = LOOKUP.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory])) + builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L) + factoryMethod.invoke(builder).asInstanceOf[ThreadFactory] + } catch { + case NonFatal(e) => + // --add-opens java.base/java.lang=ALL-UNNAMED + throw new UnsupportedOperationException("Failed to create virtual thread factory.", e) + } override def execute(body: Runnable): Unit = val th = VTFactory.newThread(body) @@ -31,11 +114,9 @@ object VThreadScheduler extends Scheduler: } private object ScheduledRunnable: - val interruptGuardVar = - MethodHandles - .lookup() - .in(classOf[ScheduledRunnable]) - .findVarHandle(classOf[ScheduledRunnable], "interruptGuard", classOf[Boolean]) + val interruptGuardVar = LOOKUP + .in(classOf[ScheduledRunnable]) + .findVarHandle(classOf[ScheduledRunnable], "interruptGuard", classOf[Boolean]) object VThreadSupport extends AsyncSupport: