-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
chore: dedicated gears scheduler #117
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,98 @@ | ||
package gears.async | ||
|
||
import java.lang.invoke.{MethodHandles, VarHandle} | ||
import java.lang.invoke.{MethodHandles, MethodType} | ||
import java.util.concurrent.TimeUnit.SECONDS | ||
import java.util.concurrent.atomic.AtomicLong | ||
import java.util.concurrent.locks.ReentrantLock | ||
import java.util.concurrent.{Executor, ForkJoinPool, ForkJoinWorkerThread, ThreadFactory} | ||
import scala.annotation.unchecked.uncheckedVariance | ||
import scala.concurrent.duration.FiniteDuration | ||
import scala.util.control.NonFatal | ||
|
||
object VThreadScheduler extends Scheduler: | ||
private val VTFactory = Thread | ||
.ofVirtual() | ||
.name("gears.async.VThread-", 0L) | ||
.factory() | ||
private val LOOKUP: MethodHandles.Lookup = MethodHandles.lookup() | ||
|
||
private object CarrierThreadFactory extends ForkJoinPool.ForkJoinWorkerThreadFactory { | ||
private val clazz = LOOKUP.findClass("jdk.internal.misc.CarrierThread") | ||
private val constructor = LOOKUP.findConstructor(clazz, MethodType.methodType(classOf[Unit], classOf[ForkJoinPool])) | ||
private val counter = new AtomicLong(0L) | ||
|
||
override def newThread(pool: ForkJoinPool): ForkJoinWorkerThread = { | ||
val t = constructor.invoke(pool).asInstanceOf[ForkJoinWorkerThread] | ||
t.setName("gears-CarrierThread-" + counter.getAndIncrement()) | ||
t | ||
} | ||
} | ||
|
||
private val DEFAULT_SCHEDULER: ForkJoinPool = { | ||
val parallelismValue = sys.props | ||
.get("gears.default-scheduler.parallelism") | ||
.map(_.toInt) | ||
.getOrElse(Runtime.getRuntime.availableProcessors()) | ||
|
||
val maxPoolSizeValue = sys.props | ||
.get("gears.default-scheduler.max-pool-size") | ||
.map(_.toInt) | ||
.getOrElse(256) | ||
|
||
val minRunnableValue = sys.props | ||
.get("gears.default-scheduler.min-runnable") | ||
.map(_.toInt) | ||
.getOrElse(parallelismValue / 2) | ||
|
||
new ForkJoinPool( | ||
parallelismValue, | ||
CarrierThreadFactory, | ||
(t: Thread, e: Throwable) => { | ||
// noop for now | ||
}, | ||
true, | ||
0, | ||
maxPoolSizeValue, | ||
minRunnableValue, | ||
(pool: ForkJoinPool) => true, | ||
60, | ||
SECONDS | ||
) | ||
} | ||
|
||
private val VTFactory = createVirtualThreadFactory("gears", DEFAULT_SCHEDULER) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you should consider making this a java static final field - either by moving to a java source, or using the apparently jvm only inlines method handles if they are static final - there is a comment out there from the hotspot maintainers There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks, but not sure why ,I can't get it compile with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. try creating There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried, but then get an error: package async
import gears.async.{Cancellable, Scheduler}
import java.lang.invoke.{MethodHandles, MethodType}
import java.util.concurrent.TimeUnit.SECONDS
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.{Executor, ForkJoinPool, ForkJoinWorkerThread, ThreadFactory}
import scala.annotation.static
import scala.concurrent.duration.FiniteDuration
import scala.util.control.NonFatal
object VThreadScheduler extends Scheduler:
@static
private val LOOKUP: MethodHandles.Lookup = MethodHandles.lookup()
@static
private val DEFAULT_SCHEDULER: ForkJoinPool = {
val parallelismValue = sys.props
.get("gears.default-scheduler.parallelism")
.map(_.toInt)
.getOrElse(Runtime.getRuntime.availableProcessors())
val maxPoolSizeValue = sys.props
.get("gears.default-scheduler.max-pool-size")
.map(_.toInt)
.getOrElse(256)
val minRunnableValue = sys.props
.get("gears.default-scheduler.min-runnable")
.map(_.toInt)
.getOrElse(parallelismValue / 2)
new ForkJoinPool(
Runtime.getRuntime.availableProcessors(),
CarrierThreadFactory,
(t: Thread, e: Throwable) => {
// noop for now
},
true,
0,
maxPoolSizeValue,
minRunnableValue,
(pool: ForkJoinPool) => true,
60,
SECONDS
)
}
@static
private val VTFactory = createVirtualThreadFactory("gears", DEFAULT_SCHEDULER)
private object CarrierThreadFactory extends ForkJoinPool.ForkJoinWorkerThreadFactory {
private val clazz = LOOKUP.findClass("jdk.internal.misc.CarrierThread")
private val constructor = LOOKUP.findConstructor(clazz, MethodType.methodType(classOf[Unit], classOf[ForkJoinPool]))
private val counter = new AtomicLong(0L)
override def newThread(pool: ForkJoinPool): ForkJoinWorkerThread = {
val t = constructor.invoke(pool).asInstanceOf[ForkJoinWorkerThread]
t.setName("gears-CarrierThread-" + counter.getAndIncrement())
t
}
}
/** Create a virtual thread factory with an executor, the executor will be used as the scheduler of virtual thread.
*
* The executor should run task on platform threads.
*
* returns null if not supported.
*/
private def createVirtualThreadFactory(prefix: String, executor: Executor): ThreadFactory =
try {
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
val ofVirtualMethod = LOOKUP.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
var builder = ofVirtualMethod.invoke()
if (executor != null) {
val clazz = builder.getClass
val privateLookup = MethodHandles.privateLookupIn(
clazz,
LOOKUP
)
val schedulerFieldSetter = privateLookup
.findSetter(clazz, "scheduler", classOf[Executor])
schedulerFieldSetter.invoke(builder, executor)
}
val nameMethod = LOOKUP.findVirtual(
ofVirtualClass,
"name",
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long])
)
val factoryMethod = LOOKUP.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
} catch {
case NonFatal(e) =>
// --add-opens java.base/java.lang=ALL-UNNAMED
throw new UnsupportedOperationException("Failed to create virtual thread factory.", e)
}
override def execute(body: Runnable): Unit =
val th = VTFactory.newThread(body)
th.start()
override def schedule(delay: FiniteDuration, body: Runnable): Cancellable = ScheduledRunnable(delay, body)
private class ScheduledRunnable(val delay: FiniteDuration, val body: Runnable) extends Cancellable {
@volatile var interruptGuard = true // to avoid interrupting the body
val th = VTFactory.newThread: () =>
try Thread.sleep(delay.toMillis)
catch case e: InterruptedException => () /* we got cancelled, don't propagate */
if ScheduledRunnable.interruptGuardVar.getAndSet(this, false) then body.run()
th.start()
final override def cancel(): Unit =
if ScheduledRunnable.interruptGuardVar.getAndSet(this, false) then th.interrupt()
}
private object ScheduledRunnable:
val interruptGuardVar = LOOKUP
.in(classOf[ScheduledRunnable])
.findVarHandle(classOf[ScheduledRunnable], "interruptGuard", classOf[Boolean])
class VThreadScheduler private ()
the Error is : unhandled exception while running MegaPhase{lambdaLift, elimStaticThis, countOuterAccesses} on /Users/hepin/IdeaProjects/gears/jvm/src/main/scala/async/VThreadScheduler.scala
An unhandled exception was thrown in the compiler.
Please file a crash report here:
https://github.com/lampepfl/dotty/issues/new/choose
For non-enriched exceptions, compile with -Yno-enrich-error-messages.
while compiling: /Users/hepin/IdeaProjects/gears/jvm/src/main/scala/async/VThreadScheduler.scala
during phase: MegaPhase{lambdaLift, elimStaticThis, countOuterAccesses}
mode: Mode(ImplicitsEnabled)
library version: version 2.13.14
compiler version: version 3.3.4
settings: -classpath /Users/hepin/IdeaProjects/gears/jvm/target/scala-3.3.4/classes:/Users/hepin/Library/Caches/Coursier/v1/https/repo1.maven.org/maven2/org/scala-lang/scala3-library_3/3.3.4/scala3-library_3-3.3.4.jar:/Users/hepin/Library/Caches/Coursier/v1/https/repo1.maven.org/maven2/org/scala-lang/scala-library/2.13.14/scala-library-2.13.14.jar -d /Users/hepin/IdeaProjects/gears/jvm/target/scala-3.3.4/classes
[error] ## Exception when compiling 24 sources to /Users/hepin/IdeaProjects/gears/jvm/target/scala-3.3.4/classes
[error] java.lang.IllegalArgumentException: Could not find proxy for val maxPoolSizeValue: Int in [value maxPoolSizeValue, value DEFAULT_SCHEDULER, object VThreadScheduler, package async, package <root>], encl = package async, owners = package async, package <root>; enclosures = package async, package <root>
[error] dotty.tools.dotc.transform.LambdaLift$Lifter.searchIn$1(LambdaLift.scala:135)
[error] dotty.tools.dotc.transform.LambdaLift$Lifter.proxy(LambdaLift.scala:148)
[error] dotty.tools.dotc.transform.LambdaLift$Lifter.proxyRef(LambdaLift.scala:166)
[error] dotty.tools.dotc.transform.LambdaLift$Lifter.addFreeArgs$$anonfun$1(LambdaLift.scala:172)
[error] scala.collection.immutable.List.map(List.scala:247)
[error] dotty.tools.dotc.transform.LambdaLift$Lifter.addFreeArgs(LambdaLift.scala:172)
[error] dotty.tools.dotc.transform.LambdaLift.transformApply(LambdaLift.scala:310)
[error] dotty.tools.dotc.transform.LambdaLift.transformApply(LambdaLift.scala:309)
[error] dotty.tools.dotc.transform.MegaPhase.goApply(MegaPhase.scala:675)
[error] dotty.tools.dotc.transform.MegaPhase.transformUnnamed$1(MegaPhase.scala:291)
[error] dotty.tools.dotc.transform.MegaPhase.transformTree(MegaPhase.scala:448)
[error] dotty.tools.dotc.transform.MegaPhase.mapValDef$1(MegaPhase.scala:245)
[error] dotty.tools.dotc.transform.MegaPhase.transformNamed$1(MegaPhase.scala:250)
[error] dotty.tools.dotc.transform.MegaPhase.transformTree(MegaPhase.scala:446)
[error] dotty.tools.dotc.transform.MegaPhase.loop$1(MegaPhase.scala:459)
[error] dotty.tools.dotc.transform.MegaPhase.transformStats(MegaPhase.scala:459)
[error] dotty.tools.dotc.transform.MegaPhase.mapPackage$1(MegaPhase.scala:390)
[error] dotty.tools.dotc.transform.MegaPhase.transformUnnamed$1(MegaPhase.scala:393)
[error] dotty.tools.dotc.transform.MegaPhase.transformTree(MegaPhase.scala:448)
[error] dotty.tools.dotc.transform.MegaPhase.transformUnit(MegaPhase.scala:475)
[error] dotty.tools.dotc.transform.MegaPhase.run(MegaPhase.scala:487)
[error] dotty.tools.dotc.core.Phases$Phase.runOn$$anonfun$1(Phases.scala:336)
[error] scala.runtime.function.JProcedure1.apply(JProcedure1.java:15)
[error] scala.runtime.function.JProcedure1.apply(JProcedure1.java:10)
[error] scala.collection.immutable.List.foreach(List.scala:334)
[error] dotty.tools.dotc.core.Phases$Phase.runOn(Phases.scala:333)
[error] dotty.tools.dotc.Run.runPhases$1$$anonfun$1(Run.scala:315)
[error] scala.runtime.function.JProcedure1.apply(JProcedure1.java:15)
[error] scala.runtime.function.JProcedure1.apply(JProcedure1.java:10)
[error] scala.collection.ArrayOps$.foreach$extension(ArrayOps.scala:1323)
[error] dotty.tools.dotc.Run.runPhases$1(Run.scala:308)
[error] dotty.tools.dotc.Run.compileUnits$$anonfun$1(Run.scala:349)
[error] dotty.tools.dotc.Run.compileUnits$$anonfun$adapted$1(Run.scala:358)
[error] dotty.tools.dotc.util.Stats$.maybeMonitored(Stats.scala:69)
[error] dotty.tools.dotc.Run.compileUnits(Run.scala:358)
[error] dotty.tools.dotc.Run.compileSources(Run.scala:261)
[error] dotty.tools.dotc.Run.compile(Run.scala:246)
[error] dotty.tools.dotc.Driver.doCompile(Driver.scala:37)
[error] dotty.tools.xsbt.CompilerBridgeDriver.run(CompilerBridgeDriver.java:141)
[error] dotty.tools.xsbt.CompilerBridge.run(CompilerBridge.java:22)
[error] sbt.internal.inc.AnalyzingCompiler.compile(AnalyzingCompiler.scala:91)
[error] sbt.internal.inc.MixedAnalyzingCompiler.$anonfun$compile$7(MixedAnalyzingCompiler.scala:196)
[error] scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
[error] sbt.internal.inc.MixedAnalyzingCompiler.timed(MixedAnalyzingCompiler.scala:252)
[error] sbt.internal.inc.MixedAnalyzingCompiler.$anonfun$compile$4(MixedAnalyzingCompiler.scala:186)
[error] sbt.internal.inc.MixedAnalyzingCompiler.$anonfun$compile$4$adapted(MixedAnalyzingCompiler.scala:166)
[error] sbt.internal.inc.JarUtils$.withPreviousJar(JarUtils.scala:241)
[error] sbt.internal.inc.MixedAnalyzingCompiler.compileScala$1(MixedAnalyzingCompiler.scala:166)
[error] sbt.internal.inc.MixedAnalyzingCompiler.compile(MixedAnalyzingCompiler.scala:214)
[error] sbt.internal.inc.IncrementalCompilerImpl.$anonfun$compileInternal$1(IncrementalCompilerImpl.scala:542)
[error] sbt.internal.inc.IncrementalCompilerImpl.$anonfun$compileInternal$1$adapted(IncrementalCompilerImpl.scala:542)
[error] sbt.internal.inc.Incremental$.$anonfun$apply$3(Incremental.scala:178)
[error] sbt.internal.inc.Incremental$.$anonfun$apply$3$adapted(Incremental.scala:176)
[error] sbt.internal.inc.Incremental$$anon$2.run(Incremental.scala:454)
[error] sbt.internal.inc.IncrementalCommon$CycleState.next(IncrementalCommon.scala:117)
[error] sbt.internal.inc.IncrementalCommon$$anon$1.next(IncrementalCommon.scala:56)
[error] sbt.internal.inc.IncrementalCommon$$anon$1.next(IncrementalCommon.scala:52)
[error] sbt.internal.inc.IncrementalCommon.cycle(IncrementalCommon.scala:265)
[error] sbt.internal.inc.Incremental$.$anonfun$incrementalCompile$8(Incremental.scala:409)
[error] sbt.internal.inc.Incremental$.withClassfileManager(Incremental.scala:496)
[error] sbt.internal.inc.Incremental$.incrementalCompile(Incremental.scala:396)
[error] sbt.internal.inc.Incremental$.apply(Incremental.scala:204)
[error] sbt.internal.inc.IncrementalCompilerImpl.compileInternal(IncrementalCompilerImpl.scala:542)
[error] sbt.internal.inc.IncrementalCompilerImpl.$anonfun$compileIncrementally$1(IncrementalCompilerImpl.scala:496)
[error] sbt.internal.inc.IncrementalCompilerImpl.handleCompilationError(IncrementalCompilerImpl.scala:332)
[error] sbt.internal.inc.IncrementalCompilerImpl.compileIncrementally(IncrementalCompilerImpl.scala:433)
[error] sbt.internal.inc.IncrementalCompilerImpl.compile(IncrementalCompilerImpl.scala:137)
[error] sbt.Defaults$.compileIncrementalTaskImpl(Defaults.scala:2419)
[error] sbt.Defaults$.$anonfun$compileIncrementalTask$2(Defaults.scala:2369)
[error] sbt.internal.server.BspCompileTask$.$anonfun$compute$1(BspCompileTask.scala:41)
[error] sbt.internal.io.Retry$.apply(Retry.scala:47)
[error] sbt.internal.io.Retry$.apply(Retry.scala:29)
[error] sbt.internal.io.Retry$.apply(Retry.scala:24)
[error] sbt.internal.server.BspCompileTask$.compute(BspCompileTask.scala:41)
[error] sbt.Defaults$.$anonfun$compileIncrementalTask$1(Defaults.scala:2367)
[error] scala.Function1.$anonfun$compose$1(Function1.scala:49)
[error] sbt.internal.util.$tilde$greater.$anonfun$$u2219$1(TypeFunctions.scala:63)
[error] sbt.std.Transform$$anon$4.work(Transform.scala:69)
[error] sbt.Execute.$anonfun$submit$2(Execute.scala:283)
[error] sbt.internal.util.ErrorHandling$.wideConvert(ErrorHandling.scala:24)
[error] sbt.Execute.work(Execute.scala:292)
[error] sbt.Execute.$anonfun$submit$1(Execute.scala:283)
[error] sbt.ConcurrentRestrictions$$anon$4.$anonfun$submitValid$1(ConcurrentRestrictions.scala:265)
[error] sbt.CompletionService$$anon$2.call(CompletionService.scala:65)
[error] java.base/java.util.concurrent.FutureTask.run(FutureTask.java:317)
[error] java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:572)
[error] java.base/java.util.concurrent.FutureTask.run(FutureTask.java:317)
[error] java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
[error] java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
[error] java.base/java.lang.Thread.run(Thread.java:1583)
[error]
[error] stack trace is suppressed; run last rootJVM / Compile / compileIncremental for the full output
[error] (rootJVM / Compile / compileIncremental) java.lang.IllegalArgumentException: Could not find proxy for val maxPoolSizeValue: Int in [value maxPoolSizeValue, value DEFAULT_SCHEDULER, object VThreadScheduler, package async, package <root>], encl = package async, owners = package async, package <root>; enclosures = package async, package <root>
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It works if I make |
||
|
||
/** Create a virtual thread factory with an executor, the executor will be used as the scheduler of virtual thread. | ||
* | ||
* The executor should run task on platform threads. | ||
* | ||
* returns null if not supported. | ||
*/ | ||
private def createVirtualThreadFactory(prefix: String, executor: Executor): ThreadFactory = | ||
try { | ||
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder") | ||
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual") | ||
val ofVirtualMethod = LOOKUP.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass)) | ||
var builder = ofVirtualMethod.invoke() | ||
if (executor != null) { | ||
val clazz = builder.getClass | ||
val privateLookup = MethodHandles.privateLookupIn( | ||
clazz, | ||
LOOKUP | ||
) | ||
val schedulerFieldSetter = privateLookup | ||
.findSetter(clazz, "scheduler", classOf[Executor]) | ||
schedulerFieldSetter.invoke(builder, executor) | ||
} | ||
val nameMethod = LOOKUP.findVirtual( | ||
ofVirtualClass, | ||
"name", | ||
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long]) | ||
) | ||
val factoryMethod = LOOKUP.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory])) | ||
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L) | ||
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory] | ||
} catch { | ||
case NonFatal(e) => | ||
// --add-opens java.base/java.lang=ALL-UNNAMED | ||
throw new UnsupportedOperationException("Failed to create virtual thread factory.", e) | ||
} | ||
|
||
override def execute(body: Runnable): Unit = | ||
val th = VTFactory.newThread(body) | ||
|
@@ -31,11 +114,9 @@ object VThreadScheduler extends Scheduler: | |
} | ||
|
||
private object ScheduledRunnable: | ||
val interruptGuardVar = | ||
MethodHandles | ||
.lookup() | ||
.in(classOf[ScheduledRunnable]) | ||
.findVarHandle(classOf[ScheduledRunnable], "interruptGuard", classOf[Boolean]) | ||
val interruptGuardVar = LOOKUP | ||
.in(classOf[ScheduledRunnable]) | ||
.findVarHandle(classOf[ScheduledRunnable], "interruptGuard", classOf[Boolean]) | ||
|
||
object VThreadSupport extends AsyncSupport: | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this doesn't work too :(