Skip to content

Commit

Permalink
add plain sql support for composite
Browse files Browse the repository at this point in the history
  • Loading branch information
tminglei committed Feb 21, 2015
1 parent a35d97d commit f72aeb9
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,22 @@ object PlainSQLUtils {
}

///
def mkArraySetParameter[T: ClassTag](baseType: String, toStr: (T => String) = (v: T) => v.toString): SetParameter[Seq[T]] =
def mkArraySetParameter[T: ClassTag](baseType: String, toStr: (T => String) = (v: T) => v.toString,
seqToStr: Option[(Seq[T] => String)] = None): SetParameter[Seq[T]] =
new SetParameter[Seq[T]] {
def apply(v: Seq[T], pp: PositionedParameters) = internalSetArray(baseType, Option(v), pp, toStr)
def apply(v: Seq[T], pp: PositionedParameters) = internalSetArray(baseType, Option(v), pp, toStr, seqToStr)
}
def mkArrayOptionSetParameter[T: ClassTag](baseType: String, toStr: (T => String) = (v: T) => v.toString): SetParameter[Option[Seq[T]]] =
def mkArrayOptionSetParameter[T: ClassTag](baseType: String, toStr: (T => String) = (v: T) => v.toString,
seqToStr: Option[(Seq[T] => String)] = None): SetParameter[Option[Seq[T]]] =
new SetParameter[Option[Seq[T]]] {
def apply(v: Option[Seq[T]], pp: PositionedParameters) = internalSetArray(baseType, v, pp, toStr)
def apply(v: Option[Seq[T]], pp: PositionedParameters) = internalSetArray(baseType, v, pp, toStr, seqToStr)
}

private def internalSetArray[T: ClassTag](baseType: String, v: Option[Seq[T]], p: PositionedParameters, toStr: (T => String)) = {
p.pos += 1
private def internalSetArray[T: ClassTag](baseType: String, v: Option[Seq[T]], p: PositionedParameters,
toStr: (T => String), seqToStr: Option[(Seq[T] => String)]) = {
val _seqToStr = seqToStr.getOrElse(mkString(toStr) _); p.pos += 1
v match {
case Some(vList) => p.ps.setArray(p.pos, mkArray(mkString[T](toStr))(baseType, vList))
case Some(vList) => p.ps.setArray(p.pos, mkArray(_seqToStr)(baseType, vList))
case None => p.ps.setNull(p.pos, java.sql.Types.ARRAY)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,37 @@ package com.github.tminglei.slickpg

import scala.slick.driver.PostgresDriver
import scala.reflect.runtime.{universe => u, currentMirror => rm}
import scala.slick.jdbc.JdbcType
import scala.reflect.ClassTag
import composite.Struct

import scala.slick.jdbc.PositionedResult

trait PgCompositeSupport extends utils.PgCommonJdbcTypes with array.PgArrayJdbcTypes { driver: PostgresDriver =>
import PgCompositeSupportUtils._

def createCompositeJdbcType[T <: Struct](sqlTypeName: String)(
implicit ev: u.TypeTag[T], tag: ClassTag[T]): JdbcType[T] = {
def createCompositeJdbcType[T <: Struct](sqlTypeName: String)(implicit ev: u.TypeTag[T], tag: ClassTag[T]) =
new GenericJdbcType[T](sqlTypeName, mkCompositeFromString[T], mkStringFromComposite[T])
}

def createCompositeListJdbcType[T <: Struct](sqlTypeName: String)(
implicit ev: u.TypeTag[T], tag: ClassTag[T], tag1: ClassTag[List[T]]): JdbcType[List[T]] = {
@deprecated(message = "pls use `createCompositeArrayJdbcType` instead", since = "0.8.2")
def createCompositeListJdbcType[T <: Struct](sqlTypeName: String)(implicit ev: u.TypeTag[T], tag: ClassTag[T]) =
createCompositeArrayJdbcType(sqlTypeName).to(_.toList)
def createCompositeArrayJdbcType[T <: Struct](sqlTypeName: String)(implicit ev: u.TypeTag[T], tag: ClassTag[T]) =
new AdvancedArrayJdbcType[T](sqlTypeName, mkCompositeSeqFromString[T], mkStringFromCompositeSeq[T])
.to(_.toList)
}

/// Plain SQL support
def nextComposite[T <: Struct](r: PositionedResult)(implicit ev: u.TypeTag[T], tag: ClassTag[T]) =
r.nextStringOption().map(mkCompositeFromString[T])
def nextCompositeArray[T <: Struct](r: PositionedResult)(implicit ev: u.TypeTag[T], tag: ClassTag[T]) =
r.nextStringOption().map(mkCompositeSeqFromString[T])

def createCompositeSetParameter[T <: Struct](sqlTypeName: String)(implicit ev: u.TypeTag[T], tag: ClassTag[T]) =
utils.PlainSQLUtils.mkSetParameter[T](sqlTypeName, mkStringFromComposite[T])
def createCompositeOptionSetParameter[T <: Struct](sqlTypeName: String)(implicit ev: u.TypeTag[T], tag: ClassTag[T]) =
utils.PlainSQLUtils.mkOptionSetParameter[T](sqlTypeName, mkStringFromComposite[T])
def createCompositeArraySetParameter[T <: Struct](sqlTypeName: String)(implicit ev: u.TypeTag[T], tag: ClassTag[T]) =
utils.PlainSQLUtils.mkArraySetParameter[T](sqlTypeName, seqToStr = Some(mkStringFromCompositeSeq[T]))
def createCompositeOptionArraySetParameter[T <: Struct](sqlTypeName: String)(implicit ev: u.TypeTag[T], tag: ClassTag[T]) =
utils.PlainSQLUtils.mkArrayOptionSetParameter[T](sqlTypeName, seqToStr = Some(mkStringFromCompositeSeq[T]))
}

object PgCompositeSupportUtils {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ object MyPlainPostgresDriver extends PostgresDriver
tpe match {
case tpe if tpe.typeConstructor =:= u.typeOf[LTree].typeConstructor =>
(true, r.nextStringOption().flatMap(fromString(LTree.apply)))
case _ => (false, None)
case _ => super.extNextArray(tpe, r)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ package com.github.tminglei.slickpg
import org.junit._
import org.junit.Assert._
import scala.slick.driver.PostgresDriver
import scala.slick.jdbc.{StaticQuery => Q}
import scala.slick.jdbc.{StaticQuery => Q, GetResult, PositionedResult}
import scala.reflect.runtime.{universe => u}
import java.sql.Timestamp
import java.text.SimpleDateFormat
import composite.Struct
Expand Down Expand Up @@ -38,6 +39,8 @@ object PgCompositeSupportTest {
override lazy val Implicit = new Implicits with ArrayImplicits with CompositeImplicits {}
override val simple = new SimpleQL with ArrayImplicits with CompositeImplicits {}

val plainImplicits = new Implicits with CompositePlainImplicits {}

///
trait CompositeImplicits {
utils.TypeConverters.register(PgRangeSupportUtils.mkRangeFn(ts))
Expand All @@ -46,9 +49,42 @@ object PgCompositeSupportTest {
implicit val composite1TypeMapper = createCompositeJdbcType[Composite1]("composite1")
implicit val composite2TypeMapper = createCompositeJdbcType[Composite2]("composite2")
implicit val composite3TypeMapper = createCompositeJdbcType[Composite3]("composite3")
implicit val composite1ArrayTypeMapper = createCompositeListJdbcType[Composite1]("composite1")
implicit val composite2ArrayTypeMapper = createCompositeListJdbcType[Composite2]("composite2")
implicit val composite3ArrayTypeMapper = createCompositeListJdbcType[Composite3]("composite3")
implicit val composite1ArrayTypeMapper = createCompositeArrayJdbcType[Composite1]("composite1").to(_.toList)
implicit val composite2ArrayTypeMapper = createCompositeArrayJdbcType[Composite2]("composite2").to(_.toList)
implicit val composite3ArrayTypeMapper = createCompositeArrayJdbcType[Composite3]("composite3").to(_.toList)
}

trait CompositePlainImplicits extends SimpleArrayPlainImplicits {
implicit class MyCompositePositionedResult(r: PositionedResult) {
def nextComposite1() = nextComposite[Composite1](r)
def nextComposite2() = nextComposite[Composite2](r)
def nextComposite3() = nextComposite[Composite3](r)
}
override protected def extNextArray(tpe: u.Type, r: PositionedResult): (Boolean, Option[Seq[_]]) =
tpe match {
case tpe if tpe.typeConstructor =:= u.typeOf[Composite1].typeConstructor =>
(true, nextCompositeArray[Composite1](r))
case tpe if tpe.typeConstructor =:= u.typeOf[Composite2].typeConstructor =>
(true, nextCompositeArray[Composite2](r))
case tpe if tpe.typeConstructor =:= u.typeOf[Composite3].typeConstructor =>
(true, nextCompositeArray[Composite3](r))
case _ => super.extNextArray(tpe, r)
}

implicit val composite1SetParameter = createCompositeSetParameter[Composite1]("composite1")
implicit val composite1OptSetParameter = createCompositeOptionSetParameter[Composite1]("composite1")
implicit val composite1ArraySetParameter = createCompositeArraySetParameter[Composite1]("composite1")
implicit val composite1ArrayOptSetParameter = createCompositeOptionArraySetParameter[Composite1]("composite1")

implicit val composite2SetParameter = createCompositeSetParameter[Composite2]("composite2")
implicit val composite2OptSetParameter = createCompositeOptionSetParameter[Composite2]("composite2")
implicit val composite2ArraySetParameter = createCompositeArraySetParameter[Composite2]("composite2")
implicit val composite2ArrayOptSetParameter = createCompositeOptionArraySetParameter[Composite2]("composite2")

implicit val composite3SetParameter = createCompositeSetParameter[Composite3]("composite3")
implicit val composite3OptSetParameter = createCompositeOptionSetParameter[Composite3]("composite3")
implicit val composite3ArraySetParameter = createCompositeArraySetParameter[Composite3]("composite3")
implicit val composite3ArrayOptSetParameter = createCompositeOptionArraySetParameter[Composite3]("composite3")
}
}
}
Expand Down Expand Up @@ -96,7 +132,6 @@ class PgCompositeSupportTest {

@Test
def testCompositeTypes(): Unit = {

db withSession { implicit session: Session =>
CompositeTests forceInsertAll (rec1, rec2, rec3)

Expand All @@ -118,7 +153,6 @@ class PgCompositeSupportTest {

@Test
def testCompositeTypes1(): Unit = {

db withSession { implicit session: Session =>
CompositeTests1 forceInsertAll (rec11, rec12, rec13)

Expand All @@ -132,6 +166,25 @@ class PgCompositeSupportTest {
assertEquals(rec13, q3.first)
}
}

@Test
def testPlainCompositeTypes(): Unit = {
import MyPostgresDriver1.plainImplicits._

implicit val getTestBeanResult = GetResult(r => TestBean(r.nextLong(), r.nextArray[Composite2]().toList))
implicit val getTestBean1Result = GetResult(r => TestBean1(r.nextLong(), r.nextArray[Composite3]().toList))

db withSession { implicit session: Session =>
(Q.u + "insert into \"CompositeTest\" values(" +? rec1.id + ", " +? rec1.comps + ")").execute
(Q.u + "insert into \"CompositeTest1\" values(" +? rec11.id + ", " +? rec11.comps + ")").execute

val found1 = (Q[TestBean] + "select * from \"CompositeTest\" where id = " +? rec1.id).first
val found11 = (Q[TestBean1] + "select * from \"CompositeTest1\" where id = " +? rec11.id).first

assertEquals(rec1, found1)
assertEquals(rec11, found11)
}
}

//////////////////////////////////////////////////////////////////////

Expand Down

0 comments on commit f72aeb9

Please sign in to comment.