Skip to content

Commit

Permalink
chore: dedicated gears scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin committed Jan 20, 2025
1 parent 84d00f3 commit 3047289
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 11 deletions.
6 changes: 5 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ lazy val root =
)
.jvmSettings(
Seq(
javaOptions += "--version 21"
javaOptions += "--version 21",
Test / javaOptions ++= Seq(
"--add-opens=java.base/java.lang=ALL-UNNAMED",
"--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED"
)
)
)
.nativeSettings(
Expand Down
101 changes: 91 additions & 10 deletions jvm/src/main/scala/async/VThreadSupport.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,98 @@
package gears.async

import java.lang.invoke.{MethodHandles, VarHandle}
import java.lang.invoke.{MethodHandles, MethodType}
import java.util.concurrent.TimeUnit.SECONDS
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.locks.ReentrantLock
import java.util.concurrent.{Executor, ForkJoinPool, ForkJoinWorkerThread, ThreadFactory}
import scala.annotation.unchecked.uncheckedVariance
import scala.concurrent.duration.FiniteDuration
import scala.util.control.NonFatal

object VThreadScheduler extends Scheduler:
private val VTFactory = Thread
.ofVirtual()
.name("gears.async.VThread-", 0L)
.factory()
private val LOOKUP: MethodHandles.Lookup = MethodHandles.lookup()

private object CarrierThreadFactory extends ForkJoinPool.ForkJoinWorkerThreadFactory {
private val clazz = LOOKUP.findClass("jdk.internal.misc.CarrierThread")
private val constructor = LOOKUP.findConstructor(clazz, MethodType.methodType(classOf[Unit], classOf[ForkJoinPool]))
private val counter = new AtomicLong(0L)

override def newThread(pool: ForkJoinPool): ForkJoinWorkerThread = {
val t = constructor.invoke(pool).asInstanceOf[ForkJoinWorkerThread]
t.setName("gears-CarrierThread-" + counter.getAndIncrement())
t
}
}

private val DEFAULT_SCHEDULER: ForkJoinPool = {
val parallelismValue = sys.props
.get("gears.default-scheduler.parallelism")
.map(_.toInt)
.getOrElse(Runtime.getRuntime.availableProcessors())

val maxPoolSizeValue = sys.props
.get("gears.default-scheduler.max-pool-size")
.map(_.toInt)
.getOrElse(256)

val minRunnableValue = sys.props
.get("gears.default-scheduler.min-runnable")
.map(_.toInt)
.getOrElse(parallelismValue / 2)

new ForkJoinPool(
parallelismValue,
CarrierThreadFactory,
(t: Thread, e: Throwable) => {
// noop for now
},
true,
0,
maxPoolSizeValue,
minRunnableValue,
(pool: ForkJoinPool) => true,
60,
SECONDS
)
}

private val VTFactory = createVirtualThreadFactory("gears", DEFAULT_SCHEDULER)

/** Create a virtual thread factory with an executor, the executor will be used as the scheduler of virtual thread.
*
* The executor should run task on platform threads.
*
* returns null if not supported.
*/
private def createVirtualThreadFactory(prefix: String, executor: Executor): ThreadFactory =
try {
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
val ofVirtualMethod = LOOKUP.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
var builder = ofVirtualMethod.invoke()
if (executor != null) {
val clazz = builder.getClass
val privateLookup = MethodHandles.privateLookupIn(
clazz,
LOOKUP
)
val schedulerFieldSetter = privateLookup
.findSetter(clazz, "scheduler", classOf[Executor])
schedulerFieldSetter.invoke(builder, executor)
}
val nameMethod = LOOKUP.findVirtual(
ofVirtualClass,
"name",
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long])
)
val factoryMethod = LOOKUP.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
} catch {
case NonFatal(e) =>
// --add-opens java.base/java.lang=ALL-UNNAMED
throw new UnsupportedOperationException("Failed to create virtual thread factory.", e)
}

override def execute(body: Runnable): Unit =
val th = VTFactory.newThread(body)
Expand All @@ -31,11 +114,9 @@ object VThreadScheduler extends Scheduler:
}

private object ScheduledRunnable:
val interruptGuardVar =
MethodHandles
.lookup()
.in(classOf[ScheduledRunnable])
.findVarHandle(classOf[ScheduledRunnable], "interruptGuard", classOf[Boolean])
val interruptGuardVar = LOOKUP
.in(classOf[ScheduledRunnable])
.findVarHandle(classOf[ScheduledRunnable], "interruptGuard", classOf[Boolean])

object VThreadSupport extends AsyncSupport:

Expand Down

0 comments on commit 3047289

Please sign in to comment.