From 24b405bcdbb4a7b1c371c838f02526054bbb3a6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=93lafur=20P=C3=A1ll=20Geirsson?= Date: Tue, 8 Dec 2015 17:00:32 +0100 Subject: [PATCH] Format using shortest path search. This commit adds: - Formatter implementation based on Dijkstra's shortest path search algorithm. This way, we reduce the formatting problem to assigning a weight to each formatting option between any two non-whitespace tokens. - dartfmt inspired testing via FormatTest, see https://github.com/dart-lang/dart_style/tree/2fdfe08b92b0ee90492832df60d1eee9e47afcb5/test/splitting - logging using logback + scala-logging via ScalaFmtLogger. --- build.sbt | 6 + project/build.properties | 2 +- scalafmt/src/main/resources/logback.xml | 21 ++ .../main/scala/org/scalafmt/ScalaFmt.scala | 204 ++++++++++++++++++ .../scala/org/scalafmt/ScalaFmtLogger.scala | 42 ++++ .../main/scala/org/scalafmt/ScalaStyle.scala | 4 + scalafmt/src/test/resources/basic.test | 45 ++++ .../test/scala/org/scalafmt/DiffUtil.scala | 49 +++++ .../test/scala/org/scalafmt/FilesUtil.scala | 17 ++ .../test/scala/org/scalafmt/FormatTest.scala | 40 ++++ .../test/scala/org/scalafmt/ProjectTest.scala | 15 +- 11 files changed, 430 insertions(+), 15 deletions(-) create mode 100644 scalafmt/src/main/resources/logback.xml create mode 100644 scalafmt/src/main/scala/org/scalafmt/ScalaFmt.scala create mode 100644 scalafmt/src/main/scala/org/scalafmt/ScalaFmtLogger.scala create mode 100644 scalafmt/src/main/scala/org/scalafmt/ScalaStyle.scala create mode 100644 scalafmt/src/test/resources/basic.test create mode 100644 scalafmt/src/test/scala/org/scalafmt/DiffUtil.scala create mode 100644 scalafmt/src/test/scala/org/scalafmt/FilesUtil.scala create mode 100644 scalafmt/src/test/scala/org/scalafmt/FormatTest.scala diff --git a/build.sbt b/build.sbt index 5d8f8feea7..5b7dc97a87 100644 --- a/build.sbt +++ b/build.sbt @@ -1,3 +1,5 @@ +triggeredMessage in ThisBuild := Watched.clearWhenTriggered + lazy val scalafmt = 
project.in(file("scalafmt")).settings( name := "scalafmt", organization := "org.scalafmt", @@ -5,8 +7,12 @@ lazy val scalafmt = project.in(file("scalafmt")).settings( scalaVersion := "2.11.7", resolvers += "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", libraryDependencies ++= Seq( + "com.googlecode.java-diff-utils" % "diffutils" % "1.3.0", + "com.typesafe.scala-logging" %% "scala-logging" % "3.1.0", + "ch.qos.logback" % "logback-classic" % "1.1.3", "org.scalameta" %% "scalameta" % "0.1.0-SNAPSHOT", "org.scalatest" %% "scalatest" % "2.2.1" % "test" ) ) + diff --git a/project/build.properties b/project/build.properties index 748703f770..817bc38df8 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=0.13.7 +sbt.version=0.13.9 diff --git a/scalafmt/src/main/resources/logback.xml b/scalafmt/src/main/resources/logback.xml new file mode 100644 index 0000000000..474e46d8be --- /dev/null +++ b/scalafmt/src/main/resources/logback.xml @@ -0,0 +1,21 @@ + + + true + + [%highlight(%-5level)] %-25(%file:%line) %msg%n + + + + + /tmp/test.log + true + + %d{HH:mm:ss.SSS} TKD [%thread] %-5level %logger{36} - %msg%n + + + + + + + + \ No newline at end of file diff --git a/scalafmt/src/main/scala/org/scalafmt/ScalaFmt.scala b/scalafmt/src/main/scala/org/scalafmt/ScalaFmt.scala new file mode 100644 index 0000000000..c1a3b0c079 --- /dev/null +++ b/scalafmt/src/main/scala/org/scalafmt/ScalaFmt.scala @@ -0,0 +1,204 @@ +package org.scalafmt + +import scala.collection.mutable +import scala.meta._ +import scala.meta.tokens.Token._ + +trait Split + +case object NoSplit extends Split + +case object Space extends Split + +case object Newline extends Split + +/** + * A state represents one potential solution to reach token at index. + * @param cost The penalty for using path + * @param index The index of the current token. + * @param path The splits/decisions made to reach here. 
+ */ +case class State(cost: Int, + index: Int, + path: List[Split]) extends Ordered[State] { + + import scala.math.Ordered.orderingToOrdered + + def compare(that: State): Int = + (-this.cost, this.index) compare(-that.cost, that.index) +} + +class ScalaFmt(style: ScalaStyle) extends ScalaFmtLogger { + + /** + * Pretty-prints Scala code. + */ + def format(code: String): String = { + val source = code.parse[Source] + val realTokens = source.tokens.filter(!_.isInstanceOf[Whitespace]) + val path = shortestPath(source, realTokens) + val sb = new StringBuilder() + realTokens.zip(path).foreach { + case (tok, split) => + sb.append(tok.code) + split match { + case Space => + sb.append(" ") + case Newline => + sb.append("\n") + case NoSplit => + } + } + sb.toString() + } + + /** + * Runs Dijkstra's shortest path algorithm to find lowest penalty split. + */ + def shortestPath(source: Source, realTokens: Tokens): List[Split] = { + val owners = getOwners(source) + val Q = new mutable.PriorityQueue[State]() + var explored = 0 + // First state. + Q += State(0, 0, Nil) + while (Q.nonEmpty) { + val curr = Q.dequeue() + explored += 1 + if (explored % 100000 == 0) + println(explored) + val tokens = realTokens + .drop(curr.index) + .dropWhile(_.isInstanceOf[Whitespace]) + val left = tokens.head + if (left.isInstanceOf[EOF]) + return curr.path.reverse + val right = tokens.tail + .find(!_.isInstanceOf[Whitespace]) + .getOrElse(tokens.last) + val between = tokens.drop(1).takeWhile(_.isInstanceOf[Whitespace]) + val splits = splitPenalty(owners, left, between, right) + splits.foreach { + case (split, cost) => + Q.enqueue(State(curr.cost + cost, curr.index + 1, split :: curr.path)) + } + } + // Could not find path to final token. + ??? + } + + /** + * Assigns cost of splitting between two non-whitespace tokens. 
+ */ + def splitPenalty(owners: Map[Token, Tree], + left: Token, + between: Tokens, + right: Token): List[(Split, Int)] = { + (left, right) match { + case (_: BOF, _) => List( + NoSplit -> 0 + ) + case (_, _: EOF) => List( + NoSplit -> 0 + ) + case (_, _) if left.name.startsWith("xml") && + right.name.startsWith("xml") => List( + NoSplit -> 0 + ) + case (_, _: `,`) => List( + NoSplit -> 0 + ) + case (_: `,`, _) => List( + Space -> 0, + Newline -> 1 + ) + case (_: `{`, _) => List( + Space -> 0, + Newline -> 0 + ) + case (_, _: `{`) => List( + Space -> 0 + ) + case (_, _: `}`) => List( + Space -> 0, + Newline -> 1 + ) + case (_, _: `:`) => List( + NoSplit -> 0 + ) + case (_, _: `=`) => List( + Space -> 0 + ) + case (_: `:` | _: `=`, _) => List( + Space -> 0 + ) + case (_, _: `@`) => List( + Newline -> 0 + ) + case (_: `@`, _) => List( + NoSplit -> 0 + ) + case (_: Ident, _: `.` | _: `#`) => List( + NoSplit -> 0 + ) + case (_: `.` | _: `#`, _: Ident) => List( + NoSplit -> 0 + ) + case (_: Ident | _: Literal, _: Ident | _: Literal) => List( + Space -> 0 + ) + case (_, _: `)` | _: `]`) => List( + NoSplit -> 0 + ) + case (_, _: `(` | _: `[`) => List( + NoSplit -> 0 + ) + case (_: `(` | _: `[`, _) => List( + NoSplit -> 0, + Newline -> 1 + ) + case (_, _: `val`) => List( + Space -> 0, + Newline -> 1 + ) + case (_: Keyword | _: Modifier, _) => List( + Space -> 1, + Newline -> 2 + ) + case (_, _: Keyword) => List( + Space -> 0, + Newline -> 1 + ) + case (_, c: Comment) => List( + Space -> 0 + ) + case (c: Comment, _) => + if (c.code.startsWith("//")) List(Newline -> 0) + else List(Space -> 0, Newline -> 1) + case (_, _: Delim) => List( + Space -> 0 + ) + case (_: Delim, _) => List( + Space -> 0 + ) + case _ => + logger.debug(s"60 ===========\n${log(left)}\n${log(between)}\n${log(right)}") + ??? + } + } + + /** + * Creates a lookup table from each token to its closest containing scala.meta tree. 
+ */ + def getOwners(source: Source): Map[Token, Tree] = { + val result = mutable.Map.empty[Token, Tree] + def loop(x: Tree): Unit = { + x.tokens + .foreach { tok => + result += tok -> x + } + x.children.foreach(loop) + } + loop(source) + result.toMap + } +} diff --git a/scalafmt/src/main/scala/org/scalafmt/ScalaFmtLogger.scala b/scalafmt/src/main/scala/org/scalafmt/ScalaFmtLogger.scala new file mode 100644 index 0000000000..0ed709f65c --- /dev/null +++ b/scalafmt/src/main/scala/org/scalafmt/ScalaFmtLogger.scala @@ -0,0 +1,42 @@ +package org.scalafmt + +import com.typesafe.scalalogging.Logger +import org.slf4j.LoggerFactory + +import scala.meta.Tree +import scala.meta.prettyprinters.Structure +import scala.meta.tokens.Token +import scala.meta.tokens.Tokens + +trait ScalaFmtLogger { + val logger = Logger(LoggerFactory.getLogger(this.getClass)) + + private def getTokenClass(token: Token) = + token.getClass.getName.stripPrefix("scala.meta.tokens.Token$") + + def log(token: Token): String = f"$token%30s ${getTokenClass(token)}" + def log(tokens: Token*): String = tokens.map(log).mkString("\n") + def log(tokens: Tokens): String = tokens.map(log).mkString("\n") + + def header[T](t: T): String = { + val line = s"=" * (t.toString.length + 3) + s"$line\n=> $t\n$line" + } + + def reveal(s: String): String = + s.replaceAll("\n", "¶") + .replaceAll(" ", "∙") + + + + def log(t: Tree, line: Int): Unit = { + logger.debug( + s"""${header(line)} + |TYPE: ${t.getClass.getName.stripPrefix("scala.meta.")} + |SOURCE: $t + |STRUCTURE: ${t.show[Structure]} + |TOKENS: ${t.tokens.map(x => reveal(x.code)).mkString(",")} + |""".stripMargin) + } +} + diff --git a/scalafmt/src/main/scala/org/scalafmt/ScalaStyle.scala b/scalafmt/src/main/scala/org/scalafmt/ScalaStyle.scala new file mode 100644 index 0000000000..e502eda34e --- /dev/null +++ b/scalafmt/src/main/scala/org/scalafmt/ScalaStyle.scala @@ -0,0 +1,4 @@ +package org.scalafmt + +trait ScalaStyle +case object Standard extends ScalaStyle 
diff --git a/scalafmt/src/test/resources/basic.test b/scalafmt/src/test/resources/basic.test new file mode 100644 index 0000000000..22248c187d --- /dev/null +++ b/scalafmt/src/test/resources/basic.test @@ -0,0 +1,45 @@ +40 columns | +<<< Object definition fits in one line +@foobar object a {val x:Int=1} +>>> +@foobar object a { val x: Int = 1 } +<<< Pathological case +@ foobar("annot", { + val x = 2 + val y = 2 // y=2 + x + y +}) + object + a extends b with c { + def + foo[T:Int#Double#Triple, + R <% String]( + @annot1 + x + : Int @annot2 = 2 + , y: Int = 3): Int = { + "match" match { + case 1 | 2 => + 3 + case 2 => 2 + } + } +} +>>> +@foobar("annot", { + val x = 2 + val y = 2 // y=2 + x + y +}) +object a extends b with c { + def foo[ + T:Int#Double#Triple, + R <% String]( + @annot1 x : Int @annot2 = 2, + y: Int = 3): Int = { + "match" match { + case 1 | 2 => 3 + case 2 => 2 + } + } +} \ No newline at end of file diff --git a/scalafmt/src/test/scala/org/scalafmt/DiffUtil.scala b/scalafmt/src/test/scala/org/scalafmt/DiffUtil.scala new file mode 100644 index 0000000000..b2f22ad881 --- /dev/null +++ b/scalafmt/src/test/scala/org/scalafmt/DiffUtil.scala @@ -0,0 +1,49 @@ +package org.scalafmt + +import java.io.File +import java.text.SimpleDateFormat +import java.util.Date +import java.util.TimeZone + +import org.scalatest.exceptions.TestFailedException + +object DiffUtil extends ScalaFmtLogger { + + implicit class DiffExtension(obtained: String) { + def diff(expected: String): Boolean = { + val result = compareContents(obtained, expected) + if (result.isEmpty) true + else throw new TestFailedException( + s""" + |${header("Obtained")} + |$obtained + | + |${header("Diff")} + |$result + """.stripMargin, 1) + } + } + + def compareContents(original: String, revised: String): String = { + compareContents(original.split("\n"), revised.split("\n")) + } + + def compareContents(original: Seq[String], + revised: Seq[String]): String = { + import collection.JavaConverters._ + val 
diff = difflib.DiffUtils.diff(original.asJava, revised.asJava) + if (diff.getDeltas.isEmpty) "" + else difflib.DiffUtils.generateUnifiedDiff( + "original", "revised",original.asJava, diff, 1).asScala.drop(3).mkString("\n") + } + + def fileModificationTimeOrEpoch(file: File): String = { + val format = new SimpleDateFormat("yyyy-MM-dd hh:mm:ss Z") + if (file.exists) + format.format(new Date(file.lastModified())) + else { + format.setTimeZone(TimeZone.getTimeZone("UTC")) + format.format(new Date(0L)) + } + } +} diff --git a/scalafmt/src/test/scala/org/scalafmt/FilesUtil.scala b/scalafmt/src/test/scala/org/scalafmt/FilesUtil.scala new file mode 100644 index 0000000000..62364574b3 --- /dev/null +++ b/scalafmt/src/test/scala/org/scalafmt/FilesUtil.scala @@ -0,0 +1,17 @@ +package org.scalafmt + +object FilesUtil { + def listFiles(path: String): Vector[String] = { + def listFilesIter(s: java.io.File): Iterator[String] = { + val (dirs, files) = Option(s.listFiles()).toIterator + .flatMap(_.toIterator) + .partition(_.isDirectory) + files.map(_.getPath) ++ dirs.flatMap(listFilesIter) + } + for { + f0 <- Option(listFilesIter(new java.io.File(path))).toVector + filename <- f0 + } yield filename + } + +} diff --git a/scalafmt/src/test/scala/org/scalafmt/FormatTest.scala b/scalafmt/src/test/scala/org/scalafmt/FormatTest.scala new file mode 100644 index 0000000000..3bb800790b --- /dev/null +++ b/scalafmt/src/test/scala/org/scalafmt/FormatTest.scala @@ -0,0 +1,40 @@ +package org.scalafmt + +import scala.meta._ +import org.scalatest.FunSuite +import org.scalafmt.DiffUtil._ + +case class Test(name: String, original: String, expected: String) + +class FormatTest extends FunSuite with ScalaFmtLogger { + + val fmt = new ScalaFmt(Standard) + + def tests: Seq[Test] = { + import FilesUtil._ + for { + filename <- listFiles( + "scalafmt/src/test/resources") if filename.endsWith(".test") + test <- { + val content = new String( + 
java.nio.file.Files.readAllBytes(java.nio.file.Paths.get(filename))) + content.split("\n<<< ").tail.map { t => + val before :: expected :: Nil = t.split(">>>\n", 2).toList + val name :: original :: Nil = before.split("\n", 2).toList + Test(name, original, expected) + } + } + } yield { + test + } + } + + tests.foreach { + case Test(name, original, expected) => + test(name) { + assert(fmt.format(original) diff expected) + } + } +} + + diff --git a/scalafmt/src/test/scala/org/scalafmt/ProjectTest.scala b/scalafmt/src/test/scala/org/scalafmt/ProjectTest.scala index 0632b9b02d..af6a6dece7 100644 --- a/scalafmt/src/test/scala/org/scalafmt/ProjectTest.scala +++ b/scalafmt/src/test/scala/org/scalafmt/ProjectTest.scala @@ -51,22 +51,9 @@ class ProjectTest extends FlatSpec { } } - def listFiles(path: String): Vector[String] = { - def listFilesIter(s: java.io.File): Iterator[String] = { - val (dirs, files) = Option(s.listFiles()).toIterator - .flatMap(_.toIterator) - .partition(_.isDirectory) - files.map(_.getPath) ++ dirs.flatMap(listFilesIter) - } - - for { - f0 <- Option(listFilesIter(new java.io.File(path))).toVector - filename <- f0 - } yield filename - } - def checkRepo(url: String, filter: String => Boolean = _ => true) = { import sys.process._ + import FilesUtil._ val name = repoName(url) val path = pathRoot + name println("CLONING?")