Skip to content

Commit

Permalink
Migrate to Scala 3.3.1
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark1626 committed Mar 4, 2024
1 parent d0480c9 commit 6f30c69
Show file tree
Hide file tree
Showing 46 changed files with 270 additions and 293 deletions.
2 changes: 1 addition & 1 deletion .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
version = "2.4.2"
version = "3.8.0"
align.tokens = []
36 changes: 19 additions & 17 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
name := "Dahlia"
version := "0.0.2"

scalaVersion := "2.13.12"
scalaVersion := "3.3.1"

libraryDependencies ++= Seq(
"org.scalatest" %% "scalatest" % "3.0.8" % "test",
"org.scalatest" %% "scalatest" % "3.2.18" % "test",
"org.scalatest" %% "scalatest-funspec" % "3.2.18" % "test",
"org.scala-lang.modules" %% "scala-parser-combinators" % "2.0.0",
"com.lihaoyi" %% "fastparse" % "2.3.0",
"com.lihaoyi" %% "fastparse" % "3.0.2",
"com.github.scopt" %% "scopt" % "4.0.1",
"com.outr" %% "scribe" % "3.5.5",
"com.lihaoyi" %% "sourcecode" % "0.2.7"
Expand All @@ -16,26 +17,26 @@ scalacOptions ++= Seq(
"-deprecation",
"-unchecked",
"-feature",
"-Ywarn-unused",
"-Ywarn-value-discard",
"-Xfatal-warnings"
"-Xfatal-warnings",
"-new-syntax",
"-indent"
)

// Reload changes to this file.
Global / onChangedBuildSource := ReloadOnSourceChanges

// Disable options in sbt console.
scalacOptions in (Compile, console) ~=
Compile / console / scalacOptions ~=
(_ filterNot ((Set("-Xfatal-warnings", "-Ywarn-unused").contains(_))))

testOptions in Test += Tests.Argument("-oD")
parallelExecution in Test := false
logBuffered in Test := false
Test / testOptions += Tests.Argument("-oD")
Test / parallelExecution := false
Test / logBuffered := false

/* Store commit hash information */
resourceGenerators in Compile += Def.task {
Compile / resourceGenerators += Def.task {
import scala.sys.process._
val file = (resourceManaged in Compile).value / "version.properties"
val file = (Compile / resourceManaged).value / "version.properties"
val gitHash = "git rev-parse --short HEAD".!!
val gitDiff = "git diff --stat".!!
val status = if (gitDiff.trim() != "") "dirty" else "clean"
Expand All @@ -48,11 +49,12 @@ resourceGenerators in Compile += Def.task {
}

/* sbt-assembly configuration: build an executable jar. */
assemblyOption in assembly := (assemblyOption in assembly).value.copy(
prependShellScript = Some(sbtassembly.AssemblyPlugin.defaultShellScript)
)
assemblyJarName in assembly := "fuse.jar"
test in assembly := {}
//assembly / assemblyOption := (assembly / assemblyOption).value.copy(
// prependShellScript = Some(sbtassembly.AssemblyPlugin.defaultShellScript)
//)
ThisBuild / assemblyPrependShellScript := Some(sbtassembly.AssemblyPlugin.defaultShellScript)
assembly / assemblyJarName := "fuse.jar"
assembly / test := {}

/* Define task to download picojson headers */
val getHeaders = taskKey[Unit]("Download header dependencies for runnable backend.")
Expand Down
2 changes: 1 addition & 1 deletion fuse
2 changes: 1 addition & 1 deletion project/assembly.sbt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.9")
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.1.5")
10 changes: 4 additions & 6 deletions src/main/scala/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ object Compiler {
)

def showDebug(ast: Prog, pass: String, c: Config): Unit = {
if (c.passDebug) {
if c.passDebug then {
val top = ("=" * 15) + pass + ("=" * 15)
println(top)
println(Pretty.emitProg(ast)(c.logLevel == scribe.Level.Debug).trim)
Expand All @@ -49,7 +49,7 @@ object Compiler {
showDebug(preAst, "Original", c)

// Run pre transformers if lowering is enabled
val ast = if (c.enableLowering) {
val ast = if c.enableLowering then {
preTransformers.foldLeft(preAst)({
case (ast, (name, pass)) => {
val newAst = pass.rewrite(ast)
Expand Down Expand Up @@ -115,7 +115,7 @@ object Compiler {
err match {
case _: Errors.TypeError => {
s"[${red("Type error")}] ${err.getMessage}" +
(if (c.enableLowering)
(if c.enableLowering then
"\nDoes this program type check without the `--lower` flag? If it does, please report this as a bug: https://github.com/cucapra/dahlia/issues/new"
else "")
}
Expand All @@ -129,9 +129,7 @@ object Compiler {
})
.map(out => {
// Get metadata about the compiler build.
val version = getClass.getResourceAsStream("/version.properties")
val meta = Source
.fromInputStream(version)
val meta = scala.io.Source.fromResource("version.properties")
.getLines()
.filter(l => l.trim != "")
.mkString(", ")
Expand Down
12 changes: 6 additions & 6 deletions src/main/scala/GenerateExec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@ object GenerateExec {


// Not the compiler directory, check if the fallback directory has been setup.
if (Files.exists(headerLocation) == false) {
if Files.exists(headerLocation) == false then {
// Fallback for headers not setup. Unpack headers from JAR file.
headerLocation = headerFallbackLocation

if (Files.exists(headerFallbackLocation) == false) {
if Files.exists(headerFallbackLocation) == false then {
scribe.warn(
s"Missing headers required for `fuse run`." +
s" Unpacking from JAR file into $headerFallbackLocation."
)

val dir = Files.createDirectory(headerFallbackLocation)
for (header <- headers) {
for header <- headers do {
val stream = getClass.getResourceAsStream(s"/headers/$header")
val hdrSource = Source.fromInputStream(stream).toArray.map(_.toByte)
Files.write(
Expand All @@ -57,8 +57,8 @@ object GenerateExec {
): Either[String, Int] = {

// Make sure all headers are downloaded.
for (header <- headers) {
if (Files.exists(headerLocation.resolve(header)) == false) {
for header <- headers do {
if Files.exists(headerLocation.resolve(header)) == false then {
throw HeaderMissing(header, headerLocation.toString)
}
}
Expand All @@ -75,7 +75,7 @@ object GenerateExec {
scribe.info(cmd.mkString(" "))
val status = cmd ! logger

if (status != 0) {
if status != 0 then {
Left(s"Failed to generate the executable $out.\n${stderr}")
} else {
Right(status)
Expand Down
31 changes: 10 additions & 21 deletions src/main/scala/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ import Compiler._
import common.Logger
import common.Configuration._

object Main {

object Main:
// Command-line names for backends.
val backends = Map(
"vivado" -> Vivado,
Expand All @@ -23,9 +22,7 @@ object Main {
"axi" -> Axi
)

val version = getClass.getResourceAsStream("/version.properties")
val meta = Source
.fromInputStream(version)
val meta = scala.io.Source.fromResource("version.properties")
.getLines()
.filter(l => l.trim != "")
.map(d => {
Expand Down Expand Up @@ -53,7 +50,7 @@ object Main {
opt[String]('n', "name")
.valueName("<kernel>")
.validate(x =>
if (x.matches("[A-Za-z0-9_]+")) success
if x.matches("[A-Za-z0-9_]+") then success
else failure("Kernel name should only contain alphanumerals and _")
)
.action((x, c) => c.copy(kernelName = x))
Expand All @@ -62,7 +59,7 @@ object Main {
opt[String]('b', "backend")
.valueName("<backend>")
.validate(b =>
if (backends.contains(b)) success
if backends.contains(b) then success
else
failure(
s"Invalid backend name. Valid backends are ${backends.keys.mkString(", ")}"
Expand All @@ -89,7 +86,7 @@ object Main {

opt[String]("memory-interface")
.validate(b =>
if (memoryInterfaces.contains(b)) success
if memoryInterfaces.contains(b) then success
else
failure(
s"Invalid memory interface. Valid memory interfaces are ${memoryInterfaces.keys.mkString(", ")}"
Expand Down Expand Up @@ -117,7 +114,7 @@ object Main {
)
}

def runWithConfig(conf: Config): Either[String, Int] = {
def runWithConfig(conf: Config): Either[String, Int] =
type ErrString = String

val path = conf.srcFile.toPath
Expand Down Expand Up @@ -146,25 +143,17 @@ object Main {
case _ => Right(0)
}
)

status
}

def main(args: Array[String]): Unit = {

parser.parse(args, emptyConf) match {
case Some(conf) => {
def main(args: Array[String]): Unit =
parser.parse(args, emptyConf) match
case Some(conf) =>
Logger.setLogLevel(conf.logLevel)
val status = runWithConfig(conf)
sys.exit(
status.left
.map(compileErr => { System.err.println(compileErr); 1 })
.merge
)
}
case None => {
case None =>
sys.exit(1)
}
}
}
}
26 changes: 8 additions & 18 deletions src/main/scala/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,26 @@ import Configuration.stringToBackend
import Utils.RichOption
import CompilerError.BackendError

case class Parser(input: String) {
case class Parser(input: String):

// Common surround expressions
def braces[K: P, T](p: => P[T]): P[T] = P("{" ~/ p ~ "}")
def brackets[K: P, T](p: => P[T]): P[T] = P("[" ~ p ~ "]")
def angular[K: P, T](p: => P[T]): P[T] = P("<" ~/ p ~ ">")
def parens[K: P, T](p: => P[T]): P[T] = P("(" ~/ p ~ ")")

def positioned[K: P, T <: PositionalWithSpan](p: => P[T]): P[T] = {
def positioned[K: P, T <: PositionalWithSpan](p: => P[T]): P[T] =
P(Index ~ p ~ Index).map({
case (index, t, end) => {
val startPos = OffsetPosition(input, index)
val out = t.setPos(startPos)
val endPos = OffsetPosition(input, end)
if (startPos.line == endPos.line) {
if startPos.line == endPos.line then {
out.setSpan(end - index)
}
out
}
})
}

/*def notKws[K: P] = {
import fastparse.NoWhitespace._
Expand All @@ -43,18 +42,16 @@ case class Parser(input: String) {
) ~ &(" "))).opaque("non reserved keywords")
}*/

def kw[K: P](word: String): P[Unit] = {
def kw[K: P](word: String): P[Unit] =
import fastparse.NoWhitespace._
P(word ~ !CharsWhileIn("a-zA-Z0-9_"))
}

// Basic atoms
def iden[K: P]: P[Id] = {
def iden[K: P]: P[Id] =
import fastparse.NoWhitespace._
positioned(P(CharIn("a-zA-Z_") ~ CharsWhileIn("a-zA-Z0-9_").?).!.map({
case rest => Id(rest)
}).opaque("Expected valid identifier"))
}

def number[K: P]: P[Int] =
P(CharIn("0-9").rep(1).!.map(_.toInt)).opaque("Expected positive number")
Expand Down Expand Up @@ -469,19 +466,12 @@ case class Parser(input: String) {
})
)

def parse(): Prog = {
fastparse.parse[Prog](input, prog(_)) match {
def parse(): Prog =
fastparse.parse[Prog](input, prog(_)) match
case Parsed.Success(e, _) => e
case Parsed.Failure(_, index, extra) => {
case Parsed.Failure(_, index, extra) =>
val traced = extra.trace()
val loc = OffsetPosition(input, index)
val msg = Errors.withPos(s"Expected ${traced.failure.label}", loc)

throw Errors.ParserError(msg)
}
// XXX(rachit): Scala 2.13.4 complains that this pattern is not exhaustive.
// This is not true...
case _ => ???
}
}
}
2 changes: 1 addition & 1 deletion src/main/scala/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ object Utils {
}

@inline def assertOrThrow[T <: Throwable](cond: Boolean, except: => T) = {
if (!cond) throw except
if !cond then throw except
}

@deprecated(
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/backends/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import CompilerError.BackendError
trait Backend {

def emit(p: Syntax.Prog, c: Configuration.Config): String = {
if (c.header && (canGenerateHeader == false)) {
if c.header && (canGenerateHeader == false) then {
throw BackendError(s"Backend $this does not support header generation.")
}
emitProg(p, c)
Expand Down
10 changes: 5 additions & 5 deletions src/main/scala/backends/CppLike.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ object Cpp {
* Helper to generate a function call that might have a type parameter
*/
def cCall(f: String, tParam: Option[Doc], args: Seq[Doc]): Doc = {
text(f) <> (if (tParam.isDefined) angles(tParam.get) else emptyDoc) <>
text(f) <> (if tParam.isDefined then angles(tParam.get) else emptyDoc) <>
parens(commaSep(args))
}

Expand Down Expand Up @@ -76,7 +76,7 @@ object Cpp {
*/
def emitLet(let: CLet): Doc =
emitDecl(let.id, let.typ.get) <>
(if (let.e.isDefined) space <> equal <+> emitExpr(let.e.get)
(if let.e.isDefined then space <> equal <+> emitExpr(let.e.get)
else emptyDoc) <>
semi

Expand All @@ -93,7 +93,7 @@ object Cpp {
case EApp(fn, args) => fn <> parens(commaSep(args.map(emitExpr)))
case EInt(v, base) => value(emitBaseInt(v, base))
case ERational(d) => value(d)
case EBool(b) => value(if (b) 1 else 0)
case EBool(b) => value(if b then 1 else 0)
case EVar(id) => value(id)
case EBinop(op, e1, e2) => parens(e1 <+> text(op.toString) <+> e2)
case EArrAccess(id, idxs) =>
Expand All @@ -116,7 +116,7 @@ object Cpp {
*/
def emitRange(range: CRange): Doc = parens {
val CRange(id, _, rev, s, e, _) = range
if (rev) {
if rev then {
text("int") <+> id <+> equal <+> value(e - 1) <> semi <+>
id <+> text(">=") <+> value(s) <> semi <+>
id <> text("--")
Expand Down Expand Up @@ -167,7 +167,7 @@ object Cpp {
)
.getOrElse(emptyDoc)

if (entry) text("extern") <+> quote(text("C")) <+> scope(body)
if entry then text("extern") <+> quote(text("C")) <+> scope(body)
else body
}

Expand Down
Loading

0 comments on commit 6f30c69

Please sign in to comment.