Skip to content

Commit

Permalink
Hoist binop (#315)
Browse files Browse the repository at this point in the history
* add a pass that buffers read and writes to memories

* some comments

* initial version of hoisting slow binops

* fixed delays on binary expr groups, added correct delay for sqrt, added Not
guard expr

* fix runt test

* addressed comments

* fixed an incomplete pattern match (why you no error Scala??)

* added position to not implemented

* attempt #1 at fixing JAVA_TOOL_OPTIONS CI failure

* attempt 2
  • Loading branch information
sgpthomas authored Aug 14, 2020
1 parent 21617d2 commit 54fcd00
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 37 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/scala.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,6 @@ jobs:
args: runt --version ${{ steps.versions.outputs.runt }}

- name: Runt tests
run: runt -d -o fail
run: |
unset JAVA_TOOL_OPTIONS
runt -d -o fail
32 changes: 24 additions & 8 deletions file-tests/should-lower/mat.expect
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,30 @@ void kernel(vector<vector<float>> A0_0, vector<vector<float>> A0_1, vector<vecto
B_read0_0_1_1 = B1_1[(unsigned int)k][(unsigned int)j];
B_read0_1_1_1 = B1_1[(unsigned int)k][(unsigned int)j];
//---
float x_0_0_0 = (A_read0_0_0_0 * B_read0_0_0_0);
float x_1_0_0 = (A_read0_1_0_0 * B_read0_1_0_0);
float x_0_1_0 = (A_read0_0_1_0 * B_read0_0_1_0);
float x_1_1_0 = (A_read0_1_1_0 * B_read0_1_1_0);
float x_0_0_1 = (A_read0_0_0_1 * B_read0_0_0_1);
float x_1_0_1 = (A_read0_1_0_1 * B_read0_1_0_1);
float x_0_1_1 = (A_read0_0_1_1 * B_read0_0_1_1);
float x_1_1_1 = (A_read0_1_1_1 * B_read0_1_1_1);
float bin_read0_ = (A_read0_0_0_0 * B_read0_0_0_0);
//---
float x_0_0_0 = bin_read0_;
float bin_read1_ = (A_read0_1_0_0 * B_read0_1_0_0);
//---
float x_1_0_0 = bin_read1_;
float bin_read2_ = (A_read0_0_1_0 * B_read0_0_1_0);
//---
float x_0_1_0 = bin_read2_;
float bin_read3_ = (A_read0_1_1_0 * B_read0_1_1_0);
//---
float x_1_1_0 = bin_read3_;
float bin_read4_ = (A_read0_0_0_1 * B_read0_0_0_1);
//---
float x_0_0_1 = bin_read4_;
float bin_read5_ = (A_read0_1_0_1 * B_read0_1_0_1);
//---
float x_1_0_1 = bin_read5_;
float bin_read6_ = (A_read0_0_1_1 * B_read0_0_1_1);
//---
float x_0_1_1 = bin_read6_;
float bin_read7_ = (A_read0_1_1_1 * B_read0_1_1_1);
//---
float x_1_1_1 = bin_read7_;
//---
C0_0[(unsigned int)i][(unsigned int)j] += (x_0_0_0 + x_0_0_1);
C1_1[(unsigned int)i][(unsigned int)j] += (x_1_1_0 + x_1_1_1);
Expand Down
20 changes: 12 additions & 8 deletions src/main/scala/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@ object Compiler {
"Hoist memory reads" -> passes.HoistMemoryReads,
"Sequentialize" -> passes.Sequentialize,
"Lower unroll and bank" -> passes.LowerUnroll,
"Lower for loops" -> passes.LowerForLoops
"Lower for loops" -> passes.LowerForLoops,
"Hoist slow binops" -> passes.HoistSlowBinop
)

// Transformers to execute *after* type checking. Boolean indicates if the
// pass should only run during lowering.
val postTransformers: List[(String, (TypedPartialTransformer, Boolean))] = List(
"Rewrite views" -> (passes.RewriteView, false),
"Add bitwidth" -> (passes.AddBitWidth, true)
)
val postTransformers: List[(String, (TypedPartialTransformer, Boolean))] =
List(
"Rewrite views" -> (passes.RewriteView, false),
"Add bitwidth" -> (passes.AddBitWidth, true)
)

def showDebug(ast: Prog, pass: String, c: Config): Unit = {
if (c.passDebug) {
Expand Down Expand Up @@ -82,9 +84,11 @@ object Compiler {

def codegen(ast: Prog, c: Config = emptyConf) = {
// Filter out transformers not running in this mode
val toRun = postTransformers.filter({ case (_, (_, onlyLower)) => {
!onlyLower || c.enableLowering
}})
val toRun = postTransformers.filter({
case (_, (_, onlyLower)) => {
!onlyLower || c.enableLowering
}
})
// Run post transformers
val transformedAst = toRun.foldLeft(ast)({
case (ast, (name, (pass, _))) => {
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/GenerateExec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ object GenerateExec {
var headerLocation = Paths.get("src/main/resources/headers")
val headerFallbackLocation = Paths.get("_headers/")


// Not the compiler directory, check if the fallback directory has been setup.
if (Files.exists(headerLocation) == false) {
// Fallback for headers not setup. Unpack headers from JAR file.
Expand Down
19 changes: 14 additions & 5 deletions src/main/scala/backends/futil/FutilAst.scala
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,13 @@ object Futil {
case Atom(item) => item.doc
case And(left, right) => parens(left.doc <+> text("&") <+> right.doc)
case Or(left, right) => parens(left.doc <+> text("|") <+> right.doc)
case Not(inner) => text("!") <> inner.doc
}
}
case class Atom(item: Port) extends GuardExpr
case class And(left: GuardExpr, right: GuardExpr) extends GuardExpr
case class Or(left: GuardExpr, right: GuardExpr) extends GuardExpr
case class Not(inner: GuardExpr) extends GuardExpr

/***** control *****/
sealed trait Control extends Emitable {
Expand Down Expand Up @@ -213,11 +215,11 @@ object Futil {
text("if") <+> port.doc <+> text("with") <+>
cond.doc <+>
scope(trueBr.doc) <> (
if (falseBr == Empty)
emptyDoc
else
space <> text("else") <+> scope(falseBr.doc)
)
if (falseBr == Empty)
emptyDoc
else
space <> text("else") <+> scope(falseBr.doc)
)
case While(port, cond, body) =>
text("while") <+> port.doc <+> text("with") <+>
cond.doc <+>
Expand Down Expand Up @@ -315,4 +317,11 @@ object Stdlib {

def sqrt(): Futil.CompInst =
Futil.CompInst("std_sqrt", List())

val staticTimingMap: Map[String, Option[Int]] = Map(
"sqrt" -> Some(17),
"mult" -> Some(3),
"div" -> None,
"mod" -> None
)
}
95 changes: 80 additions & 15 deletions src/main/scala/backends/futil/FutilBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ private class FutilBackendHelper {
def bitsForType(t: Option[Type], pos: Position): Int = {
t match {
case Some(TSizedInt(width, _)) => width
case Some(_:TBool) => 1
case Some(_:TVoid) => 0
case Some(_: TBool) => 1
case Some(_: TVoid) => 0
case x =>
throw NotImplemented(
s"Futil cannot infer bitwidth for type $x. Please manually annotate it using a cast expression.",
Expand Down Expand Up @@ -188,7 +188,50 @@ private class FutilBackendHelper {
comp.id.port("out"),
ConstantPort(1, 1),
struct ++ e1Out.structure ++ e2Out.structure,
Some(0)
for (d1 <- e1Out.delay; d2 <- e2Out.delay)
yield d1 + d2
)
}

def emitMultiCycleBinop(
compName: String,
e1: Expr,
e2: Expr,
delay: Option[Int]
)(
implicit store: Store
): EmitOutput = {
val e1Out = emitExpr(e1)
val e2Out = emitExpr(e2)
val e1Bits = bitsForType(e1.typ, e1.pos)
val e2Bits = bitsForType(e2.typ, e2.pos)
assertOrThrow(
e1Bits == e2Bits,
Impossible(
"The widths of the left and right side of a binop didn't match." +
s"\nleft: ${Pretty.emitExpr(e1)(false).pretty}: ${e1Bits}" +
s"\nright: ${Pretty.emitExpr(e2)(false).pretty}: ${e2Bits}"
)
)
val binop = Stdlib.op(s"$compName", bitsForType(e1.typ, e1.pos));

val comp = LibDecl(genName(compName), binop)
val struct = List(
comp,
Connect(e1Out.port, comp.id.port("left")),
Connect(e2Out.port, comp.id.port("right")),
Connect(
ConstantPort(1, 1),
comp.id.port("go"),
Some(Not(Atom(comp.id.port("done"))))
)
)
EmitOutput(
comp.id.port("out"),
comp.id.port("done"),
struct ++ e1Out.structure ++ e2Out.structure,
for (d1 <- e1Out.delay; d2 <- e2Out.delay; d3 <- delay)
yield d1 + d2 + d3
)
}

Expand All @@ -203,23 +246,26 @@ private class FutilBackendHelper {
implicit store: Store
): EmitOutput =
expr match {
case _:EInt => {
throw PassError("Cannot compile unannotated constants. Wrap constant in `as` expression", expr.pos)
case _: EInt => {
throw PassError(
"Cannot compile unannotated constants. Wrap constant in `as` expression",
expr.pos
)
}
case EBinop(op, e1, e2) => {
val compName =
op.op match {
case "+" => "add"
case "-" => "sub"
case "*" => "mult"
case "/" => "div"
case "*" => "mult_pipe"
case "/" => "div_pipe"
case "<" => "lt"
case ">" => "gt"
case "<=" => "le"
case ">=" => "ge"
case "!=" => "neq"
case "==" => "eq"
case "%" => "mod"
case "%" => "mod_pipe"
case "&&" => "and"
case "||" => "or"
case "&" => "and"
Expand All @@ -232,7 +278,20 @@ private class FutilBackendHelper {
op.pos
)
}
emitBinop(compName, e1, e2)
op.op match {
case "*" =>
emitMultiCycleBinop(
compName,
e1,
e2,
Stdlib.staticTimingMap("mult")
)
case "/" =>
emitMultiCycleBinop(compName, e1, e2, Stdlib.staticTimingMap("div"))
case "%" =>
emitMultiCycleBinop(compName, e1, e2, Stdlib.staticTimingMap("mod"))
case _ => emitBinop(compName, e1, e2)
}
}
case EVar(id) =>
val portName = if (rhsInfo.isDefined) "in" else "out"
Expand Down Expand Up @@ -274,9 +333,12 @@ private class FutilBackendHelper {
}
case ECast(e, t) => {
val vBits = bitsForType(e.typ, e.pos)
val cBits = bitsForType (Some(t), e.pos)
val cBits = bitsForType(Some(t), e.pos)
if (cBits > vBits) {
throw NotImplemented("Cast expressions that imply zero-padding", expr.pos)
throw NotImplemented(
"Cast expressions that imply zero-padding",
expr.pos
)
}
val res = emitExpr(e)
val sliceOp =
Expand Down Expand Up @@ -339,13 +401,17 @@ private class FutilBackendHelper {
val struct = List(
sqrt,
Connect(argOut.port, sqrt.id.port("in")),
Connect(ConstantPort(1, 1), sqrt.id.port("go"))
Connect(
ConstantPort(1, 1),
sqrt.id.port("go"),
Some(Not(Atom(sqrt.id.port("done"))))
)
)
EmitOutput(
sqrt.id.port("out"),
sqrt.id.port("done"),
argOut.structure ++ struct,
Some(1)
Stdlib.staticTimingMap("sqrt")
)
}
case x =>
Expand All @@ -355,7 +421,6 @@ private class FutilBackendHelper {
def emitCmd(
c: Command
)(implicit store: Store): (List[Structure], Control, Store) = {
//println(Pretty.emitCmd(c)(false).pretty)
c match {
case CBlock(cmd) => emitCmd(cmd)
case CPar(cmds) => {
Expand Down Expand Up @@ -395,7 +460,7 @@ private class FutilBackendHelper {
)
val struct =
Connect(out.port, reg.id.port("in")) :: Connect(
ConstantPort(1, 1),
out.done,
reg.id.port("write_en")
) :: doneHole :: out.structure
val (group, st) =
Expand Down
Loading

0 comments on commit 54fcd00

Please sign in to comment.