Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: dedicated gears scheduler #117

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ lazy val root =
)
.jvmSettings(
Seq(
javaOptions += "--version 21"
javaOptions += "--version 21",
Test / javaOptions ++= Seq(
"--add-opens=java.base/java.lang=ALL-UNNAMED",
"--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this doesn't work too :(

)
)
)
.nativeSettings(
Expand Down
101 changes: 91 additions & 10 deletions jvm/src/main/scala/async/VThreadSupport.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,98 @@
package gears.async

import java.lang.invoke.{MethodHandles, VarHandle}
import java.lang.invoke.{MethodHandles, MethodType}
import java.util.concurrent.TimeUnit.SECONDS
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.locks.ReentrantLock
import java.util.concurrent.{Executor, ForkJoinPool, ForkJoinWorkerThread, ThreadFactory}
import scala.annotation.unchecked.uncheckedVariance
import scala.concurrent.duration.FiniteDuration
import scala.util.control.NonFatal

object VThreadScheduler extends Scheduler:
private val VTFactory = Thread
.ofVirtual()
.name("gears.async.VThread-", 0L)
.factory()
private val LOOKUP: MethodHandles.Lookup = MethodHandles.lookup()

private object CarrierThreadFactory extends ForkJoinPool.ForkJoinWorkerThreadFactory {
private val clazz = LOOKUP.findClass("jdk.internal.misc.CarrierThread")
private val constructor = LOOKUP.findConstructor(clazz, MethodType.methodType(classOf[Unit], classOf[ForkJoinPool]))
private val counter = new AtomicLong(0L)

override def newThread(pool: ForkJoinPool): ForkJoinWorkerThread = {
val t = constructor.invoke(pool).asInstanceOf[ForkJoinWorkerThread]
t.setName("gears-CarrierThread-" + counter.getAndIncrement())
t
}
}

private val DEFAULT_SCHEDULER: ForkJoinPool = {
val parallelismValue = sys.props
.get("gears.default-scheduler.parallelism")
.map(_.toInt)
.getOrElse(Runtime.getRuntime.availableProcessors())

val maxPoolSizeValue = sys.props
.get("gears.default-scheduler.max-pool-size")
.map(_.toInt)
.getOrElse(256)

val minRunnableValue = sys.props
.get("gears.default-scheduler.min-runnable")
.map(_.toInt)
.getOrElse(parallelismValue / 2)

new ForkJoinPool(
parallelismValue,
CarrierThreadFactory,
(t: Thread, e: Throwable) => {
// noop for now
},
true,
0,
maxPoolSizeValue,
minRunnableValue,
(pool: ForkJoinPool) => true,
60,
SECONDS
)
}

private val VTFactory = createVirtualThreadFactory("gears", DEFAULT_SCHEDULER)
Copy link
Member

@bishabosha bishabosha Jan 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should consider making this a java static final field - either by moving to a java source, or using the @static annotation (scala.annotation.static) https://docs.scala-lang.org/sips/static-members.html

apparently jvm only inlines method handles if they are static final - there is a comment out there from the hotspot maintainers

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, but not sure why ,I can't get it compile with @static and the help message not helpful.

Copy link
Member

@bishabosha bishabosha Jan 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try creating class VThreadScheduler private () - it needs a normal class to put the static field into

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried, but then get an error:

package async

import gears.async.{Cancellable, Scheduler}

import java.lang.invoke.{MethodHandles, MethodType}
import java.util.concurrent.TimeUnit.SECONDS
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.{Executor, ForkJoinPool, ForkJoinWorkerThread, ThreadFactory}
import scala.annotation.static
import scala.concurrent.duration.FiniteDuration
import scala.util.control.NonFatal

object VThreadScheduler extends Scheduler:
  @static
  private val LOOKUP: MethodHandles.Lookup = MethodHandles.lookup()
  @static
  private val DEFAULT_SCHEDULER: ForkJoinPool = {
    val parallelismValue = sys.props
      .get("gears.default-scheduler.parallelism")
      .map(_.toInt)
      .getOrElse(Runtime.getRuntime.availableProcessors())

    val maxPoolSizeValue = sys.props
      .get("gears.default-scheduler.max-pool-size")
      .map(_.toInt)
      .getOrElse(256)

    val minRunnableValue = sys.props
      .get("gears.default-scheduler.min-runnable")
      .map(_.toInt)
      .getOrElse(parallelismValue / 2)

    new ForkJoinPool(
      Runtime.getRuntime.availableProcessors(),
      CarrierThreadFactory,
      (t: Thread, e: Throwable) => {
        // noop for now
      },
      true,
      0,
      maxPoolSizeValue,
      minRunnableValue,
      (pool: ForkJoinPool) => true,
      60,
      SECONDS
    )
  }

  @static
  private val VTFactory = createVirtualThreadFactory("gears", DEFAULT_SCHEDULER)

  private object CarrierThreadFactory extends ForkJoinPool.ForkJoinWorkerThreadFactory {
    private val clazz = LOOKUP.findClass("jdk.internal.misc.CarrierThread")
    private val constructor = LOOKUP.findConstructor(clazz, MethodType.methodType(classOf[Unit], classOf[ForkJoinPool]))
    private val counter = new AtomicLong(0L)

    override def newThread(pool: ForkJoinPool): ForkJoinWorkerThread = {
      val t = constructor.invoke(pool).asInstanceOf[ForkJoinWorkerThread]
      t.setName("gears-CarrierThread-" + counter.getAndIncrement())
      t
    }
  }

  /** Create a virtual thread factory with an executor, the executor will be used as the scheduler of virtual thread.
    *
    * The executor should run task on platform threads.
    *
    * returns null if not supported.
    */
  private def createVirtualThreadFactory(prefix: String, executor: Executor): ThreadFactory =
    try {
      val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
      val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
      val ofVirtualMethod = LOOKUP.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
      var builder = ofVirtualMethod.invoke()
      if (executor != null) {
        val clazz = builder.getClass
        val privateLookup = MethodHandles.privateLookupIn(
          clazz,
          LOOKUP
        )
        val schedulerFieldSetter = privateLookup
          .findSetter(clazz, "scheduler", classOf[Executor])
        schedulerFieldSetter.invoke(builder, executor)
      }
      val nameMethod = LOOKUP.findVirtual(
        ofVirtualClass,
        "name",
        MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long])
      )
      val factoryMethod = LOOKUP.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
      builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
      factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
    } catch {
      case NonFatal(e) =>
        // --add-opens java.base/java.lang=ALL-UNNAMED
        throw new UnsupportedOperationException("Failed to create virtual thread factory.", e)
    }

  override def execute(body: Runnable): Unit =
    val th = VTFactory.newThread(body)
    th.start()

  override def schedule(delay: FiniteDuration, body: Runnable): Cancellable = ScheduledRunnable(delay, body)

  private class ScheduledRunnable(val delay: FiniteDuration, val body: Runnable) extends Cancellable {
    @volatile var interruptGuard = true // to avoid interrupting the body

    val th = VTFactory.newThread: () =>
      try Thread.sleep(delay.toMillis)
      catch case e: InterruptedException => () /* we got cancelled, don't propagate */
      if ScheduledRunnable.interruptGuardVar.getAndSet(this, false) then body.run()
    th.start()

    final override def cancel(): Unit =
      if ScheduledRunnable.interruptGuardVar.getAndSet(this, false) then th.interrupt()
  }

  private object ScheduledRunnable:
    val interruptGuardVar = LOOKUP
      .in(classOf[ScheduledRunnable])
      .findVarHandle(classOf[ScheduledRunnable], "interruptGuard", classOf[Boolean])

class VThreadScheduler private ()

the Error is :

  unhandled exception while running MegaPhase{lambdaLift, elimStaticThis, countOuterAccesses} on /Users/hepin/IdeaProjects/gears/jvm/src/main/scala/async/VThreadScheduler.scala

  An unhandled exception was thrown in the compiler.
  Please file a crash report here:
  https://github.com/lampepfl/dotty/issues/new/choose
  For non-enriched exceptions, compile with -Yno-enrich-error-messages.

     while compiling: /Users/hepin/IdeaProjects/gears/jvm/src/main/scala/async/VThreadScheduler.scala
        during phase: MegaPhase{lambdaLift, elimStaticThis, countOuterAccesses}
                mode: Mode(ImplicitsEnabled)
     library version: version 2.13.14
    compiler version: version 3.3.4
            settings: -classpath /Users/hepin/IdeaProjects/gears/jvm/target/scala-3.3.4/classes:/Users/hepin/Library/Caches/Coursier/v1/https/repo1.maven.org/maven2/org/scala-lang/scala3-library_3/3.3.4/scala3-library_3-3.3.4.jar:/Users/hepin/Library/Caches/Coursier/v1/https/repo1.maven.org/maven2/org/scala-lang/scala-library/2.13.14/scala-library-2.13.14.jar -d /Users/hepin/IdeaProjects/gears/jvm/target/scala-3.3.4/classes

[error] ## Exception when compiling 24 sources to /Users/hepin/IdeaProjects/gears/jvm/target/scala-3.3.4/classes
[error] java.lang.IllegalArgumentException: Could not find proxy for val maxPoolSizeValue: Int in [value maxPoolSizeValue, value DEFAULT_SCHEDULER, object VThreadScheduler, package async, package <root>], encl = package async, owners = package async, package <root>; enclosures = package async, package <root>
[error] dotty.tools.dotc.transform.LambdaLift$Lifter.searchIn$1(LambdaLift.scala:135)
[error] dotty.tools.dotc.transform.LambdaLift$Lifter.proxy(LambdaLift.scala:148)
[error] dotty.tools.dotc.transform.LambdaLift$Lifter.proxyRef(LambdaLift.scala:166)
[error] dotty.tools.dotc.transform.LambdaLift$Lifter.addFreeArgs$$anonfun$1(LambdaLift.scala:172)
[error] scala.collection.immutable.List.map(List.scala:247)
[error] dotty.tools.dotc.transform.LambdaLift$Lifter.addFreeArgs(LambdaLift.scala:172)
[error] dotty.tools.dotc.transform.LambdaLift.transformApply(LambdaLift.scala:310)
[error] dotty.tools.dotc.transform.LambdaLift.transformApply(LambdaLift.scala:309)
[error] dotty.tools.dotc.transform.MegaPhase.goApply(MegaPhase.scala:675)
[error] dotty.tools.dotc.transform.MegaPhase.transformUnnamed$1(MegaPhase.scala:291)
[error] dotty.tools.dotc.transform.MegaPhase.transformTree(MegaPhase.scala:448)
[error] dotty.tools.dotc.transform.MegaPhase.mapValDef$1(MegaPhase.scala:245)
[error] dotty.tools.dotc.transform.MegaPhase.transformNamed$1(MegaPhase.scala:250)
[error] dotty.tools.dotc.transform.MegaPhase.transformTree(MegaPhase.scala:446)
[error] dotty.tools.dotc.transform.MegaPhase.loop$1(MegaPhase.scala:459)
[error] dotty.tools.dotc.transform.MegaPhase.transformStats(MegaPhase.scala:459)
[error] dotty.tools.dotc.transform.MegaPhase.mapPackage$1(MegaPhase.scala:390)
[error] dotty.tools.dotc.transform.MegaPhase.transformUnnamed$1(MegaPhase.scala:393)
[error] dotty.tools.dotc.transform.MegaPhase.transformTree(MegaPhase.scala:448)
[error] dotty.tools.dotc.transform.MegaPhase.transformUnit(MegaPhase.scala:475)
[error] dotty.tools.dotc.transform.MegaPhase.run(MegaPhase.scala:487)
[error] dotty.tools.dotc.core.Phases$Phase.runOn$$anonfun$1(Phases.scala:336)
[error] scala.runtime.function.JProcedure1.apply(JProcedure1.java:15)
[error] scala.runtime.function.JProcedure1.apply(JProcedure1.java:10)
[error] scala.collection.immutable.List.foreach(List.scala:334)
[error] dotty.tools.dotc.core.Phases$Phase.runOn(Phases.scala:333)
[error] dotty.tools.dotc.Run.runPhases$1$$anonfun$1(Run.scala:315)
[error] scala.runtime.function.JProcedure1.apply(JProcedure1.java:15)
[error] scala.runtime.function.JProcedure1.apply(JProcedure1.java:10)
[error] scala.collection.ArrayOps$.foreach$extension(ArrayOps.scala:1323)
[error] dotty.tools.dotc.Run.runPhases$1(Run.scala:308)
[error] dotty.tools.dotc.Run.compileUnits$$anonfun$1(Run.scala:349)
[error] dotty.tools.dotc.Run.compileUnits$$anonfun$adapted$1(Run.scala:358)
[error] dotty.tools.dotc.util.Stats$.maybeMonitored(Stats.scala:69)
[error] dotty.tools.dotc.Run.compileUnits(Run.scala:358)
[error] dotty.tools.dotc.Run.compileSources(Run.scala:261)
[error] dotty.tools.dotc.Run.compile(Run.scala:246)
[error] dotty.tools.dotc.Driver.doCompile(Driver.scala:37)
[error] dotty.tools.xsbt.CompilerBridgeDriver.run(CompilerBridgeDriver.java:141)
[error] dotty.tools.xsbt.CompilerBridge.run(CompilerBridge.java:22)
[error] sbt.internal.inc.AnalyzingCompiler.compile(AnalyzingCompiler.scala:91)
[error] sbt.internal.inc.MixedAnalyzingCompiler.$anonfun$compile$7(MixedAnalyzingCompiler.scala:196)
[error] scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
[error] sbt.internal.inc.MixedAnalyzingCompiler.timed(MixedAnalyzingCompiler.scala:252)
[error] sbt.internal.inc.MixedAnalyzingCompiler.$anonfun$compile$4(MixedAnalyzingCompiler.scala:186)
[error] sbt.internal.inc.MixedAnalyzingCompiler.$anonfun$compile$4$adapted(MixedAnalyzingCompiler.scala:166)
[error] sbt.internal.inc.JarUtils$.withPreviousJar(JarUtils.scala:241)
[error] sbt.internal.inc.MixedAnalyzingCompiler.compileScala$1(MixedAnalyzingCompiler.scala:166)
[error] sbt.internal.inc.MixedAnalyzingCompiler.compile(MixedAnalyzingCompiler.scala:214)
[error] sbt.internal.inc.IncrementalCompilerImpl.$anonfun$compileInternal$1(IncrementalCompilerImpl.scala:542)
[error] sbt.internal.inc.IncrementalCompilerImpl.$anonfun$compileInternal$1$adapted(IncrementalCompilerImpl.scala:542)
[error] sbt.internal.inc.Incremental$.$anonfun$apply$3(Incremental.scala:178)
[error] sbt.internal.inc.Incremental$.$anonfun$apply$3$adapted(Incremental.scala:176)
[error] sbt.internal.inc.Incremental$$anon$2.run(Incremental.scala:454)
[error] sbt.internal.inc.IncrementalCommon$CycleState.next(IncrementalCommon.scala:117)
[error] sbt.internal.inc.IncrementalCommon$$anon$1.next(IncrementalCommon.scala:56)
[error] sbt.internal.inc.IncrementalCommon$$anon$1.next(IncrementalCommon.scala:52)
[error] sbt.internal.inc.IncrementalCommon.cycle(IncrementalCommon.scala:265)
[error] sbt.internal.inc.Incremental$.$anonfun$incrementalCompile$8(Incremental.scala:409)
[error] sbt.internal.inc.Incremental$.withClassfileManager(Incremental.scala:496)
[error] sbt.internal.inc.Incremental$.incrementalCompile(Incremental.scala:396)
[error] sbt.internal.inc.Incremental$.apply(Incremental.scala:204)
[error] sbt.internal.inc.IncrementalCompilerImpl.compileInternal(IncrementalCompilerImpl.scala:542)
[error] sbt.internal.inc.IncrementalCompilerImpl.$anonfun$compileIncrementally$1(IncrementalCompilerImpl.scala:496)
[error] sbt.internal.inc.IncrementalCompilerImpl.handleCompilationError(IncrementalCompilerImpl.scala:332)
[error] sbt.internal.inc.IncrementalCompilerImpl.compileIncrementally(IncrementalCompilerImpl.scala:433)
[error] sbt.internal.inc.IncrementalCompilerImpl.compile(IncrementalCompilerImpl.scala:137)
[error] sbt.Defaults$.compileIncrementalTaskImpl(Defaults.scala:2419)
[error] sbt.Defaults$.$anonfun$compileIncrementalTask$2(Defaults.scala:2369)
[error] sbt.internal.server.BspCompileTask$.$anonfun$compute$1(BspCompileTask.scala:41)
[error] sbt.internal.io.Retry$.apply(Retry.scala:47)
[error] sbt.internal.io.Retry$.apply(Retry.scala:29)
[error] sbt.internal.io.Retry$.apply(Retry.scala:24)
[error] sbt.internal.server.BspCompileTask$.compute(BspCompileTask.scala:41)
[error] sbt.Defaults$.$anonfun$compileIncrementalTask$1(Defaults.scala:2367)
[error] scala.Function1.$anonfun$compose$1(Function1.scala:49)
[error] sbt.internal.util.$tilde$greater.$anonfun$$u2219$1(TypeFunctions.scala:63)
[error] sbt.std.Transform$$anon$4.work(Transform.scala:69)
[error] sbt.Execute.$anonfun$submit$2(Execute.scala:283)
[error] sbt.internal.util.ErrorHandling$.wideConvert(ErrorHandling.scala:24)
[error] sbt.Execute.work(Execute.scala:292)
[error] sbt.Execute.$anonfun$submit$1(Execute.scala:283)
[error] sbt.ConcurrentRestrictions$$anon$4.$anonfun$submitValid$1(ConcurrentRestrictions.scala:265)
[error] sbt.CompletionService$$anon$2.call(CompletionService.scala:65)
[error] java.base/java.util.concurrent.FutureTask.run(FutureTask.java:317)
[error] java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:572)
[error] java.base/java.util.concurrent.FutureTask.run(FutureTask.java:317)
[error] java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
[error] java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
[error] java.base/java.lang.Thread.run(Thread.java:1583)
[error]            
[error] stack trace is suppressed; run last rootJVM / Compile / compileIncremental for the full output
[error] (rootJVM / Compile / compileIncremental) java.lang.IllegalArgumentException: Could not find proxy for val maxPoolSizeValue: Int in [value maxPoolSizeValue, value DEFAULT_SCHEDULER, object VThreadScheduler, package async, package <root>], encl = package async, owners = package async, package <root>; enclosures = package async, package <root>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It works if I make parallelismValue a static field too, I think this is a bug in Scala 3 compiler.


/** Create a virtual thread factory with an executor, the executor will be used as the scheduler of virtual thread.
*
* The executor should run task on platform threads.
*
* returns null if not supported.
*/
private def createVirtualThreadFactory(prefix: String, executor: Executor): ThreadFactory =
try {
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
val ofVirtualMethod = LOOKUP.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
var builder = ofVirtualMethod.invoke()
if (executor != null) {
val clazz = builder.getClass
val privateLookup = MethodHandles.privateLookupIn(
clazz,
LOOKUP
)
val schedulerFieldSetter = privateLookup
.findSetter(clazz, "scheduler", classOf[Executor])
schedulerFieldSetter.invoke(builder, executor)
}
val nameMethod = LOOKUP.findVirtual(
ofVirtualClass,
"name",
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long])
)
val factoryMethod = LOOKUP.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
} catch {
case NonFatal(e) =>
// --add-opens java.base/java.lang=ALL-UNNAMED
throw new UnsupportedOperationException("Failed to create virtual thread factory.", e)
}

override def execute(body: Runnable): Unit =
val th = VTFactory.newThread(body)
Expand All @@ -31,11 +114,9 @@ object VThreadScheduler extends Scheduler:
}

private object ScheduledRunnable:
val interruptGuardVar =
MethodHandles
.lookup()
.in(classOf[ScheduledRunnable])
.findVarHandle(classOf[ScheduledRunnable], "interruptGuard", classOf[Boolean])
val interruptGuardVar = LOOKUP
.in(classOf[ScheduledRunnable])
.findVarHandle(classOf[ScheduledRunnable], "interruptGuard", classOf[Boolean])

object VThreadSupport extends AsyncSupport:

Expand Down
Loading