Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Codec#product to behave correctly with parameter count #94

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions core/js/src/main/scala/porcupine/dbplatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@ private abstract class DatabasePlatform:
def cursor(args: A): Resource[F, Cursor[F, B]] = mutex.lock *>
Resource
.eval {
F.delay(statement.iterate(bind(args)*)).map { iterator =>
val argsList = bind(args)
val argsDict: js.Dictionary[Any] =
js.Dictionary(argsList.zipWithIndex.map { case (v, idx) =>
(s"${idx + 1}", v)
}*)
F.delay(statement.iterate(argsDict)).map { iterator =>
Comment on lines +55 to +60
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's put this all in the delay block, and I think we can build the JS object more efficiently without all the intermediate datastructures form zipWithIndex and map. Instead, we can just create a new js.Object and set its keys with js.Dynamic.

new:
def fetch(maxRows: Int): F[(List[B], Boolean)] =
F.delay {
Expand Down Expand Up @@ -90,7 +95,12 @@ private abstract class DatabasePlatform:
new AbstractStatement[F, A, B]:
def cursor(args: A): Resource[F, Cursor[F, B]] =
mutex.lock *> Resource.eval {
F.delay(statement.run(bind(args)*)).as(_ => F.pure(Nil, false))
val argsList = bind(args)
val argsDict: js.Dictionary[Any] =
js.Dictionary(argsList.zipWithIndex.map { case (v, idx) =>
(s"${idx + 1}", v)
}*)
F.delay(statement.run(argsDict)).as(_ => F.pure(Nil, false))
}

}
Expand Down
5 changes: 3 additions & 2 deletions core/js/src/main/scala/porcupine/facade.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ private[porcupine] trait Statement extends js.Object:

def raw(toggleState: Boolean): Statement = js.native

def iterate(bindParameters: Any*): js.Iterator[js.UndefOr[js.Array[Any]]] = js.native
def iterate(bindParameters: js.Dictionary[Any]): js.Iterator[js.UndefOr[js.Array[Any]]] =
js.native

def run(bindParameters: Any*): js.Object = js.native
def run(bindParameters: js.Dictionary[Any]): js.Object = js.native
105 changes: 63 additions & 42 deletions core/shared/src/main/scala/porcupine/codec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,48 @@ package porcupine
import cats.Applicative
import cats.ContravariantMonoidal
import cats.InvariantMonoidal
import cats.data.StateT
import cats.data.{State, StateT}
import cats.syntax.all.*
import scodec.bits.ByteVector

import scala.deriving.Mirror

trait Encoder[A]:
outer =>

def parameters: State[Int, List[Int]]

def encode(a: A): List[LiteValue]

def either[B](right: Encoder[B]): Encoder[Either[A, B]] = new:
def encode(aorb: Either[A, B]) = aorb match
case Left(a) => outer.encode(a)
case Right(b) => right.encode(b)
// def either[B](right: Encoder[B]): Encoder[Either[A, B]] = new:
// TODO figure out if this is reasonably implementable
// def parameters: Int = ???
//
// def encode(aorb: Either[A, B]) = aorb match
// case Left(a) => outer.encode(a)
// case Right(b) => right.encode(b)

def opt: Encoder[Option[A]] =
either(Codec.`null`).contramap(_.toLeft(None))
// either(Codec.`null`).contramap(_.toLeft(None))
new:
def parameters = outer.parameters
def encode(aopt: Option[A]) = aopt match
case None => Codec.`null`.encode(None)
case Some(a) => outer.encode(a)

object Encoder:
given ContravariantMonoidal[Encoder] = new:
def unit = Codec.unit

def product[A, B](fa: Encoder[A], fb: Encoder[B]) = new:
def parameters =
(fa.parameters, fb.parameters).mapN(_ ++ _)

def encode(ab: (A, B)) =
val (a, b) = ab
fa.encode(a) ::: fb.encode(b)

def contramap[A, B](fa: Encoder[A])(f: B => A) = new:
def parameters = fa.parameters
def encode(b: B) = fa.encode(f(b))

trait Decoder[A]:
Expand Down Expand Up @@ -92,52 +105,56 @@ trait Codec[A] extends Encoder[A], Decoder[A]:
def asEncoder: Encoder[A] = this
def asDecoder: Decoder[A] = this

def either[B](right: Codec[B]): Codec[Either[A, B]] = new:
def encode(aorb: Either[A, B]) =
outer.asEncoder.either(right).encode(aorb)
// def either[B](right: Codec[B]): Codec[Either[A, B]] = new:
// def parameters: State[Int, String] =
// outer.asEncoder.either(right).parameters
//
// def encode(aorb: Either[A, B]) =
// outer.asEncoder.either(right).encode(aorb)
//
// def decode = outer.asDecoder.either(right).decode

def decode = outer.asDecoder.either(right).decode

override def opt: Codec[Option[A]] =
either(Codec.`null`).imap(_.left.toOption)(_.toLeft(None))
override def opt: Codec[Option[A]] = new:
def parameters = outer.parameters
def encode(aopt: Option[A]) = outer.asEncoder.opt.encode(aopt)
def decode = outer.asDecoder.opt.decode

object Codec:
val integer: Codec[Long] = new:
def encode(l: Long) = LiteValue.Integer(l) :: Nil
def decode = StateT {
case LiteValue.Integer(l) :: tail => Right((tail, l))
case other => Left(new RuntimeException(s"Expected integer, got ${other.headOption}"))
extension [H](head: Codec[H])
def *:[T <: Tuple](tail: Codec[T]): Codec[H *: T] = (head, tail).imapN(_ *: _) {
case h *: t => (h, t)
}

val real: Codec[Double] = new:
def encode(d: Double) = LiteValue.Real(d) :: Nil
def decode = StateT {
case LiteValue.Real(d) :: tail => Right((tail, d))
case other => Left(new RuntimeException(s"Expected real, got ${other.headOption}"))
private final class Simple[T](
name: String,
apply: T => LiteValue,
unapply: PartialFunction[LiteValue, T],
) extends Codec[T] {
override def parameters: State[Int, List[Int]] = State(idx => (idx + 1, List(idx)))
override def encode(a: T): List[LiteValue] = apply(a) :: Nil
override def decode: StateT[Either[Throwable, *], List[LiteValue], T] = StateT {
case unapply(l) :: tail => Right((tail, l))
case other => Left(new RuntimeException(s"Expected $name, got ${other.headOption}"))
}
}

val text: Codec[String] = new:
def encode(s: String) = LiteValue.Text(s) :: Nil
def decode = StateT {
case LiteValue.Text(s) :: tail => Right((tail, s))
case other => Left(new RuntimeException(s"Expected text, got ${other.headOption}"))
}
val integer: Codec[Long] =
new Simple("integer", LiteValue.Integer.apply, { case LiteValue.Integer(i) => i })

val blob: Codec[ByteVector] = new:
def encode(b: ByteVector) = LiteValue.Blob(b) :: Nil
def decode = StateT {
case LiteValue.Blob(b) :: tail => Right((tail, b))
case other => Left(new RuntimeException(s"Expected blob, got ${other.headOption}"))
}
val real: Codec[Double] =
new Simple("real", LiteValue.Real.apply, { case LiteValue.Real(r) => r })

val `null`: Codec[None.type] = new:
def encode(n: None.type) = LiteValue.Null :: Nil
def decode = StateT {
case LiteValue.Null :: tail => Right((tail, None))
case other => Left(new RuntimeException(s"Expected NULL, got ${other.headOption}"))
}
val text: Codec[String] =
new Simple("text", LiteValue.Text.apply, { case LiteValue.Text(t) => t })

val blob: Codec[ByteVector] =
new Simple("blob", LiteValue.Blob.apply, { case LiteValue.Blob(b) => b })

val `null`: Codec[None.type] =
new Simple("NULL", _ => LiteValue.Null, { case LiteValue.Null => None })

def unit: Codec[Unit] = new:
def parameters = State.pure(List.empty)
def encode(u: Unit) = Nil
def decode = StateT.pure(())

Expand All @@ -147,12 +164,16 @@ object Codec:
def unit = Codec.unit

def product[A, B](fa: Codec[A], fb: Codec[B]) = new:
def parameters =
(fa.parameters, fb.parameters).mapN(_ ++ _)

def encode(ab: (A, B)) =
val (a, b) = ab
fa.encode(a) ::: fb.encode(b)

def decode = fa.decode.product(fb.decode)

def imap[A, B](fa: Codec[A])(f: A => B)(g: B => A) = new:
def parameters = fa.parameters
def encode(b: B) = fa.encode(g(b))
def decode = fa.decode.map(f)
77 changes: 60 additions & 17 deletions core/shared/src/main/scala/porcupine/sql.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ package porcupine
import cats.ContravariantMonoidal
import cats.Monoid
import cats.arrow.Profunctor
import cats.data.State
import cats.syntax.all.*

import scala.quoted.Expr
import scala.quoted.Exprs
import scala.quoted.Quotes
Expand All @@ -34,24 +34,62 @@ object Query:
def dimap[A, B, C, D](fab: Query[A, B])(f: C => A)(g: B => D) =
Query(fab.sql, fab.encoder.contramap(f), fab.decoder.map(g))

final class Fragment[A](val fragment: String, val encoder: Encoder[A]):
def command: Query[A, Unit] = Query(fragment, encoder, Codec.unit)
final class Fragment[A](
val part: Fragment.Part,
val encoder: Encoder[A],
):
def sql: String = part.compile.runA(1).value

def command: Query[A, Unit] = Query(sql, encoder, Codec.unit)

def query[B](decoder: Decoder[B]): Query[A, B] = Query(sql, encoder, decoder)

def query[B](decoder: Decoder[B]): Query[A, B] = Query(fragment, encoder, decoder)
def apply(a: A): Fragment[Unit] = Fragment(part, encoder.contramap(_ => a))

def apply(a: A): Fragment[Unit] = Fragment(fragment, encoder.contramap(_ => a))
def stripMargin: Fragment[A] = stripMargin('|')

def stripMargin: Fragment[A] = Fragment(fragment.stripMargin, encoder)
def stripMargin(marginChar: Char): Fragment[A] =
Fragment(fragment.stripMargin(marginChar), encoder)
Fragment(part.stripMargin(true, marginChar), encoder)

object Fragment:
sealed trait Part:
def compile: State[Int, String]
def concatenate(other: Part): Part = other match {
case Part.Concatenate(values) => Part.Concatenate(this :: values)
case _ => Part.Concatenate(List(this, other))
}
def stripMargin(head: Boolean, marginChar: Char): Part

object Part:
final case class Literal(x: String) extends Part:
def compile = State.pure(x)
def stripMargin(head: Boolean, marginChar: Char) =
if (head) Literal(x.stripMargin(marginChar))
else Literal(x.takeWhile(_ != '\n') ++ x.dropWhile(_ != '\n').stripMargin(marginChar))
final case class Concatenate(values: List[Part]) extends Part:
def compile = values.traverse(_.compile).map(_.combineAll)
override def concatenate(other: Part) = other match {
case Concatenate(values) => Concatenate(this.values ++ values)
case _ => Concatenate(this.values :+ other)
}
def stripMargin(head: Boolean, marginChar: Char) =
values match {
case h :: t =>
Concatenate(
h.stripMargin(head, marginChar) :: t.map(_.stripMargin(false, marginChar)),
)
case other => this
}
final case class Parameters(advance: State[Int, List[Int]]) extends Part:
def compile = advance.map(_.map(idx => s"?$idx").mkString(", "))
def stripMargin(head: Boolean, marginChar: Char) = this

given ContravariantMonoidal[Fragment] = new:
val unit = Fragment("", Codec.unit)
val unit = Fragment(Part.Concatenate(List.empty), Codec.unit)
def product[A, B](fa: Fragment[A], fb: Fragment[B]) =
Fragment(fa.fragment + fb.fragment, (fa.encoder, fb.encoder).tupled)
Fragment(fa.part.concatenate(fb.part), (fa.encoder, fb.encoder).tupled)
def contramap[A, B](fa: Fragment[A])(f: B => A) =
Fragment(fa.fragment, fa.encoder.contramap(f))
Fragment(fa.part, fa.encoder.contramap(f))

given Monoid[Fragment[Unit]] = new:
def empty = ContravariantMonoidal[Fragment].unit
Expand All @@ -73,12 +111,16 @@ private def sqlImpl(

val args = Varargs.unapply(argsExpr).toList.flatMap(_.toList)

val fragment = parts.zipAll(args, '{ "" }, '{ "" }).foldLeft('{ "" }) {
case ('{ $acc: String }, ('{ $p: String }, '{ $s: String })) => '{ $acc + $p + $s }
case ('{ $acc: String }, ('{ $p: String }, '{ $e: Encoder[t] })) => '{ $acc + $p + "?" }
case ('{ $acc: String }, ('{ $p: String }, '{ $f: Fragment[t] })) =>
'{ $acc + $p + $f.fragment }
}
// TODO appending to `List` is slow
val fragment =
parts.zipAll(args, '{ "" }, '{ "" }).foldLeft('{ List.empty[Fragment.Part] }) {
case ('{ $acc: List[Fragment.Part] }, ('{ $p: String }, '{ $s: String })) =>
'{ $acc :+ Fragment.Part.Literal($p) :+ Fragment.Part.Literal($s) }
case ('{ $acc: List[Fragment.Part] }, ('{ $p: String }, '{ $e: Encoder[t] })) =>
'{ $acc :+ Fragment.Part.Literal($p) :+ Fragment.Part.Parameters($e.parameters) }
case ('{ $acc: List[Fragment.Part] }, ('{ $p: String }, '{ $f: Fragment[t] })) =>
'{ $acc :+ Fragment.Part.Literal($p) :+ $f.part }
}

val encoder = args.collect {
case '{ $e: Encoder[t] } => e
Expand All @@ -103,5 +145,6 @@ private def sqlImpl(
}

(fragment, encoder) match
case ('{ $s: String }, '{ $e: Encoder[a] }) => '{ Fragment[a]($s, $e) }
case ('{ $s: List[Fragment.Part] }, '{ $e: Encoder[a] }) =>
'{ Fragment[a](Fragment.Part.Concatenate($s), $e) }
case _ => sys.error("porcupine pricked itself")
6 changes: 3 additions & 3 deletions core/shared/src/test/scala/porcupine/PorcupineTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ package porcupine

import cats.effect.IOApp
import cats.effect.IO
import cats.syntax.all.*
import scodec.bits.ByteVector

import Codec.*

object PorcupineTest extends IOApp.Simple:

def run = Database.open[IO](":memory:").use { db =>
// TODO figure out why this is broken inside interpolation
val q = `null` *: integer *: real *: text *: blob *: nil
db.execute(sql"create table porcupine (n, i, r, t, b);".command) *>
db.execute(
sql"insert into porcupine values(${`null`}, $integer, $real, $text, $blob);".command,
sql"insert into porcupine values(${q});".command,
(None, 42L, 3.14, "quill-pig", ByteVector(0, 1, 2, 3)),
) *>
db.unique(
Expand Down
Loading