Skip to content

Commit

Permalink
Rewrite to Scala 3 syntax using -rewrite flag (#423)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark1626 authored Mar 8, 2024
1 parent 410d42b commit 9cbf657
Show file tree
Hide file tree
Showing 48 changed files with 1,041 additions and 2,034 deletions.
3 changes: 0 additions & 3 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@ Compile / resourceGenerators += Def.task {
}

/* sbt-assembly configuration: build an executable jar. */
//assembly / assemblyOption := (assembly / assemblyOption).value.copy(
// prependShellScript = Some(sbtassembly.AssemblyPlugin.defaultShellScript)
//)
ThisBuild / assemblyPrependShellScript := Some(sbtassembly.AssemblyPlugin.defaultShellScript)
assembly / assemblyJarName := "fuse.jar"
assembly / test := {}
Expand Down
32 changes: 11 additions & 21 deletions src/main/scala/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import Configuration._
import Syntax._
import Transformer.{PartialTransformer, TypedPartialTransformer}

object Compiler {
object Compiler:

// Transformers to execute *before* type checking.
val preTransformers: List[(String, PartialTransformer)] = List(
Expand All @@ -28,28 +28,25 @@ object Compiler {
"Add bitwidth" -> (passes.AddBitWidth, true)
)

def showDebug(ast: Prog, pass: String, c: Config): Unit = {
if c.passDebug then {
def showDebug(ast: Prog, pass: String, c: Config): Unit =
if c.passDebug then
val top = ("=" * 15) + pass + ("=" * 15)
println(top)
println(Pretty.emitProg(ast)(c.logLevel == scribe.Level.Debug).trim)
println("=" * top.length)
}
}

def toBackend(str: BackendOption): fuselang.backend.Backend = str match {
def toBackend(str: BackendOption): fuselang.backend.Backend = str match
case Vivado => backend.VivadoBackend
case Cpp => backend.CppRunnable
case Calyx => backend.calyx.CalyxBackend
}

def checkStringWithError(prog: String, c: Config = emptyConf) = {
def checkStringWithError(prog: String, c: Config = emptyConf) =
val preAst = Parser(prog).parse()

showDebug(preAst, "Original", c)

// Run pre transformers if lowering is enabled
val ast = if c.enableLowering then {
val ast = if c.enableLowering then
preTransformers.foldLeft(preAst)({
case (ast, (name, pass)) => {
val newAst = pass.rewrite(ast)
Expand All @@ -70,9 +67,8 @@ object Compiler {
} */
}
})
} else {
else
preAst
}
passes.WellFormedChecker.check(ast)
typechecker.TypeChecker.typeCheck(ast);
showDebug(ast, "Type Checking", c)
Expand All @@ -83,9 +79,8 @@ object Compiler {
showDebug(ast, "Capability Checking", c)
typechecker.AffineChecker.check(ast); // Doesn't modify the AST
ast
}

def codegen(ast: Prog, c: Config = emptyConf) = {
def codegen(ast: Prog, c: Config = emptyConf) =
// Filter out transformers not running in this mode
val toRun = postTransformers.filter({
case (_, (_, onlyLower)) => {
Expand All @@ -101,14 +96,12 @@ object Compiler {
}
})
toBackend(c.backend).emit(transformedAst, c)
}

// Outputs red text to the console
def red(txt: String): String = {
def red(txt: String): String =
Console.RED + txt + Console.RESET
}

def compileString(prog: String, c: Config): Either[String, String] = {
def compileString(prog: String, c: Config): Either[String, String] =
Try(codegen(checkStringWithError(prog, c), c)).toEither.left
.map(err => {
scribe.info(err.getStackTrace().take(10).mkString("\n"))
Expand Down Expand Up @@ -136,13 +129,12 @@ object Compiler {
val commentPre = toBackend(c.backend).commentPrefix
s"$commentPre $meta\n" + out
})
}

def compileStringToFile(
prog: String,
c: Config,
out: String
): Either[String, Path] = {
): Either[String, Path] =

compileString(prog, c).map(p => {
Files.write(
Expand All @@ -153,6 +145,4 @@ object Compiler {
StandardOpenOption.WRITE
)
})
}

}
26 changes: 9 additions & 17 deletions src/main/scala/GenerateExec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import common.CompilerError.HeaderMissing
* Provides utilities to compile a program and link it with headers required
* by the CppRunnable backend.
*/
object GenerateExec {
object GenerateExec:
// TODO(rachit): Move this to build.sbt
val headers = List("parser.cpp", "json.hpp")

Expand All @@ -19,18 +19,18 @@ object GenerateExec {


// Not the compiler directory, check if the fallback directory has been setup.
if Files.exists(headerLocation) == false then {
if Files.exists(headerLocation) == false then
// Fallback for headers not setup. Unpack headers from JAR file.
headerLocation = headerFallbackLocation

if Files.exists(headerFallbackLocation) == false then {
if Files.exists(headerFallbackLocation) == false then
scribe.warn(
s"Missing headers required for `fuse run`." +
s" Unpacking from JAR file into $headerFallbackLocation."
)

val dir = Files.createDirectory(headerFallbackLocation)
for header <- headers do {
for header <- headers do
val stream = getClass.getResourceAsStream(s"/headers/$header")
val hdrSource = Source.fromInputStream(stream).toArray.map(_.toByte)
Files.write(
Expand All @@ -39,9 +39,6 @@ object GenerateExec {
StandardOpenOption.CREATE_NEW,
StandardOpenOption.WRITE
)
}
}
}

/**
* Generates an executable object [[out]]. Assumes that [[src]] is a valid
Expand All @@ -54,14 +51,12 @@ object GenerateExec {
src: Path,
out: String,
compilerOpts: List[String]
): Either[String, Int] = {
): Either[String, Int] =

// Make sure all headers are downloaded.
for header <- headers do {
if Files.exists(headerLocation.resolve(header)) == false then {
for header <- headers do
if Files.exists(headerLocation.resolve(header)) == false then
throw HeaderMissing(header, headerLocation.toString)
}
}

val CXX =
Seq("g++", "-g", "--std=c++14", "-Wall", "-I", headerLocation.toString) ++ compilerOpts
Expand All @@ -75,10 +70,7 @@ object GenerateExec {
scribe.info(cmd.mkString(" "))
val status = cmd ! logger

if status != 0 then {
if status != 0 then
Left(s"Failed to generate the executable $out.\n${stderr}")
} else {
else
Right(status)
}
}
}
6 changes: 2 additions & 4 deletions src/main/scala/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ object Main:
})
.toMap

val parser = new scopt.OptionParser[Config]("fuse") {
val parser = new scopt.OptionParser[Config]("fuse"):

head(s"Dahlia (sha = ${meta("git.hash")}, status = ${meta("git.status")})")

Expand Down Expand Up @@ -112,16 +112,14 @@ object Main:
.action((f, c) => c.copy(output = Some(f)))
.text("Name of the output artifact.")
)
}

def runWithConfig(conf: Config): Either[String, Int] =
type ErrString = String

val path = conf.srcFile.toPath
val prog = Files.exists(path) match {
val prog = Files.exists(path) match
case true => Right(new String(Files.readAllBytes(path)))
case false => Left(s"$path: No such file in working directory")
}

val cppPath: Either[ErrString, Option[Path]] = prog.flatMap(prog =>
conf.output match {
Expand Down
39 changes: 13 additions & 26 deletions src/main/scala/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,70 +2,57 @@ package fuselang
import scala.{PartialFunction => PF}
import scala.math.{log10, ceil}

object Utils {
object Utils:

implicit class RichOption[A](opt: => Option[A]) {
def getOrThrow[T <: Throwable](except: T) = opt match {
implicit class RichOption[A](opt: => Option[A]):
def getOrThrow[T <: Throwable](except: T) = opt match
case Some(v) => v
case None => throw except
}
}

// https://codereview.stackexchange.com/questions/14561/matching-bigints-in-scala
// TODO: This can overflow and result in an runtime exception
object Big {
object Big:
def unapply(n: BigInt) = Some(n.toInt)
}

def bitsNeeded(n: Int): Int = n match {
def bitsNeeded(n: Int): Int = n match
case 0 => 1
case n if n > 0 => ceil(log10(n + 1) / log10(2)).toInt
case n if n < 0 => bitsNeeded(n.abs) + 1
}

def bitsNeeded(n: BigInt): Int = n match {
def bitsNeeded(n: BigInt): Int = n match
case Big(0) => 1
case n if n > 0 => ceil(log10((n + 1).toDouble) / log10(2)).toInt
case n if n < 0 => bitsNeeded(n.abs) + 1
}

def cartesianProduct[T](llst: Seq[Seq[T]]): Seq[Seq[T]] = {
def cartesianProduct[T](llst: Seq[Seq[T]]): Seq[Seq[T]] =
def pel(e: T, ll: Seq[Seq[T]], a: Seq[Seq[T]] = Nil): Seq[Seq[T]] =
ll match {
ll match
case Nil => a.reverse
case x +: xs => pel(e, xs, (e +: x) +: a)
}

llst match {
llst match
case Nil => Nil
case x +: Nil => x.map(Seq(_))
case x +: _ =>
x match {
x match
case Nil => Nil
case _ =>
llst
.foldRight(Seq(x))((l, a) => l.flatMap(x => pel(x, a)))
.map(_.dropRight(x.size))
}
}
}


@inline def asPartial[A, B, C](f: (A, B) => C): PF[(A, B), C] = {
@inline def asPartial[A, B, C](f: (A, B) => C): PF[(A, B), C] =
case (a, b) => f(a, b)
}

@inline def assertOrThrow[T <: Throwable](cond: Boolean, except: => T) = {
@inline def assertOrThrow[T <: Throwable](cond: Boolean, except: => T) =
if !cond then throw except
}

@deprecated(
"pr is used for debugging. Remove all call to it before committing",
"fuse 0.0.1"
)
@inline def pr[T](v: T) = {
@inline def pr[T](v: T) =
println(v)
v
}

}
9 changes: 3 additions & 6 deletions src/main/scala/backends/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@ import CompilerError.BackendError
/**
* Abstract definition of a Fuse backend.
*/
trait Backend {
trait Backend:

def emit(p: Syntax.Prog, c: Configuration.Config): String = {
if c.header && (canGenerateHeader == false) then {
def emit(p: Syntax.Prog, c: Configuration.Config): String =
if c.header && (canGenerateHeader == false) then
throw BackendError(s"Backend $this does not support header generation.")
}
emitProg(p, c)
}

/**
* Generate a String representation of the Abstract Syntax Tree of the
Expand All @@ -32,4 +30,3 @@ trait Backend {
*/
val commentPrefix: String = "//"

}
Loading

0 comments on commit 9cbf657

Please sign in to comment.