Skip to content

Commit

Permalink
[IMPLICITS] Removed 'Strict' from some methods where it was not neces…
Browse files Browse the repository at this point in the history
…sary.
  • Loading branch information
eaplatanios committed Oct 26, 2018
1 parent b7368e4 commit 7677354
Show file tree
Hide file tree
Showing 17 changed files with 75 additions and 67 deletions.
4 changes: 2 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ crossScalaVersions in ThisBuild := Seq("2.11.12", "2.12.7")

organization in ThisBuild := "org.platanios"

autoCompilerPlugins in ThisBuild := true

val tensorFlowVersion = "1.11.0"
val circeVersion = "0.10.0" // Use for working with JSON.

autoCompilerPlugins in ThisBuild := true

// addCompilerPlugin(MetalsPlugin.semanticdbScalac)

scalacOptions in ThisBuild ++= Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,23 +167,23 @@ object DataTypeStructure {

implicit def fromHList[HD, TD <: HList](implicit
evH: Strict[DataTypeStructure[HD]],
evT: Strict[DataTypeStructure[TD]]
evT: DataTypeStructure[TD]
): DataTypeStructure[HD :: TD] = {
new DataTypeStructure[HD :: TD] {
override def size(dataType: HD :: TD): Int = {
evH.value.size(dataType.head) + evT.value.size(dataType.tail)
evH.value.size(dataType.head) + evT.size(dataType.tail)
}

override def dataTypes(dataType: HD :: TD): Seq[DataType[Any]] = {
evH.value.dataTypes(dataType.head) ++ evT.value.dataTypes(dataType.tail)
evH.value.dataTypes(dataType.head) ++ evT.dataTypes(dataType.tail)
}

override def decodeDataType(
dataType: HD :: TD,
dataTypes: Seq[DataType[Any]]
): (HD :: TD, Seq[DataType[Any]]) = {
val (headOut, headRemaining) = evH.value.decodeDataType(dataType.head, dataTypes)
val (tailOut, tailRemaining) = evT.value.decodeDataType(dataType.tail, headRemaining)
val (tailOut, tailRemaining) = evT.decodeDataType(dataType.tail, headRemaining)
(headOut :: tailOut, tailRemaining)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,13 @@ object DataTypeToOutput {

implicit def fromHList[HD, HO, TD <: HList, TO <: HList](implicit
evH: Strict[DataTypeToOutput.Aux[HD, HO]],
evT: Strict[DataTypeToOutput.Aux[TD, TO]]
evT: DataTypeToOutput.Aux[TD, TO]
): DataTypeToOutput.Aux[HD :: TD, HO :: TO] = {
new DataTypeToOutput[HD :: TD] {
override type O = HO :: TO

override def dataTypeStructure: DataTypeStructure[HD :: TD] = {
DataTypeStructure.fromHList[HD, TD](evH.value.dataTypeStructure, evT.value.dataTypeStructure)
DataTypeStructure.fromHList[HD, TD](evH.value.dataTypeStructure, evT.dataTypeStructure)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,21 +164,21 @@ object DataTypeToShape {

implicit def fromHList[HD, HS, TD <: HList, TS <: HList](implicit
evH: Strict[DataTypeToShape.Aux[HD, HS]],
evT: Strict[DataTypeToShape.Aux[TD, TS]]
evT: DataTypeToShape.Aux[TD, TS]
): DataTypeToShape.Aux[HD :: TD, HS :: TS] = {
new DataTypeToShape[HD :: TD] {
override type S = HS :: TS

override def sizeFromDataType(dataType: HD :: TD): Int = {
evH.value.sizeFromDataType(dataType.head) + evT.value.sizeFromDataType(dataType.tail)
evH.value.sizeFromDataType(dataType.head) + evT.sizeFromDataType(dataType.tail)
}

override def decodeShape(
dataType: HD :: TD,
shapes: Seq[Shape]
): (HS :: TS, Seq[Shape]) = {
val (headOut, headRemaining) = evH.value.decodeShape(dataType.head, shapes)
val (tailOut, tailRemaining) = evT.value.decodeShape(dataType.tail, headRemaining)
val (tailOut, tailRemaining) = evT.decodeShape(dataType.tail, headRemaining)
(headOut :: tailOut, tailRemaining)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,12 @@ trait NestedStructureOpsLowPriority {

implicit def fromHList[H, T <: HList](implicit
evH: Strict[OpStructure[H]],
evT: Strict[OpStructure[T]]
evT: OpStructure[T]
): OpStructure[H :: T] = {
new OpStructure[H :: T] {
override def ops(executable: H :: T): Set[UntypedOp] = {
evH.value.ops(executable.head) ++
evT.value.ops(executable.tail)
evT.ops(executable.tail)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ import scala.language.higherKinds
*
* @author Emmanouil Antonios Platanios
*/
sealed trait OutputStructure[T] {
trait OutputStructure[T] {
def size(output: T): Int
def outputs(output: T): Seq[Output[Any]]
def decodeOutput(output: T, outputs: Seq[Output[Any]]): (T, Seq[Output[Any]])
Expand Down Expand Up @@ -362,25 +362,25 @@ object OutputStructure {

implicit def fromHList[HT, TT <: HList](implicit
evH: Strict[OutputStructure[HT]],
evT: Strict[OutputStructure[TT]]
evT: OutputStructure[TT]
): OutputStructure[HT :: TT] = {
new OutputStructure[HT :: TT] {
override def size(output: HT :: TT): Int = {
evH.value.size(output.head) +
evT.value.size(output.tail)
evT.size(output.tail)
}

override def outputs(output: HT :: TT): Seq[Output[Any]] = {
evH.value.outputs(output.head) ++
evT.value.outputs(output.tail)
evT.outputs(output.tail)
}

override def decodeOutput(
output: HT :: TT,
outputs: Seq[Output[Any]]
): (HT :: TT, Seq[Output[Any]]) = {
val (headOut, headRemaining) = evH.value.decodeOutput(output.head, outputs)
val (tailOut, tailRemaining) = evT.value.decodeOutput(output.tail, headRemaining)
val (tailOut, tailRemaining) = evT.decodeOutput(output.tail, headRemaining)
(headOut :: tailOut, tailRemaining)
}

Expand All @@ -389,7 +389,7 @@ object OutputStructure {
converter: OutputStructure.Converter
): HT :: TT = {
evH.value.map(value.head, converter) ::
evT.value.map(value.tail, converter)
evT.map(value.tail, converter)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,29 +326,29 @@ object OutputToDataType {

implicit def fromHList[HT, HD, TT <: HList, TD <: HList](implicit
evH: Strict[OutputToDataType.Aux[HT, HD]],
evT: Strict[OutputToDataType.Aux[TT, TD]]
evT: OutputToDataType.Aux[TT, TD]
): OutputToDataType.Aux[HT :: TT, HD :: TD] = {
new OutputToDataType[HT :: TT] {
override type D = HD :: TD

override def dataTypeStructure: DataTypeStructure[HD :: TD] = {
DataTypeStructure.fromHList[HD, TD](evH.value.dataTypeStructure, evT.value.dataTypeStructure)
DataTypeStructure.fromHList[HD, TD](evH.value.dataTypeStructure, evT.dataTypeStructure)
}

override def sizeFromDataType(dataType: HD :: TD): Int = {
evH.value.sizeFromDataType(dataType.head) + evT.value.sizeFromDataType(dataType.tail)
evH.value.sizeFromDataType(dataType.head) + evT.sizeFromDataType(dataType.tail)
}

override def dataType(output: HT :: TT): HD :: TD = {
evH.value.dataType(output.head) :: evT.value.dataType(output.tail)
evH.value.dataType(output.head) :: evT.dataType(output.tail)
}

override def decodeOutput(
dataType: HD :: TD,
outputs: Seq[Output[Any]]
): (HT :: TT, Seq[Output[Any]]) = {
val (headOut, headRemaining) = evH.value.decodeOutput(dataType.head, outputs)
val (tailOut, tailRemaining) = evT.value.decodeOutput(dataType.tail, headRemaining)
val (tailOut, tailRemaining) = evT.decodeOutput(dataType.tail, headRemaining)
(headOut :: tailOut, tailRemaining)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,35 +452,35 @@ object OutputToShape {

implicit def fromHList[HT, HS, TT <: HList, TS <: HList](implicit
evH: Strict[OutputToShape.Aux[HT, HS]],
evT: Strict[OutputToShape.Aux[TT, TS]]
evT: OutputToShape.Aux[TT, TS]
): OutputToShape.Aux[HT :: TT, HS :: TS] = {
new OutputToShape[HT :: TT] {
override type S = HS :: TS

override def outputStructure: OutputStructure[HT :: TT] = {
implicit val evOutputToShapeH: OutputStructure[HT] = evH.value.outputStructure
implicit val evOutputToShapeT: OutputStructure[TT] = evT.value.outputStructure
implicit val evOutputToShapeT: OutputStructure[TT] = evT.outputStructure
OutputStructure[HT :: TT]
}

override def shapeStructure: ShapeStructure[HS :: TS] = {
ShapeStructure.fromHList[HS, TS](evH.value.shapeStructure, evT.value.shapeStructure)
ShapeStructure.fromHList[HS, TS](evH.value.shapeStructure, evT.shapeStructure)
}

override def sizeFromOutput(output: HT :: TT): Int = {
evH.value.sizeFromOutput(output.head) + evT.value.sizeFromOutput(output.tail)
evH.value.sizeFromOutput(output.head) + evT.sizeFromOutput(output.tail)
}

override def shape(output: HT :: TT): HS :: TS = {
evH.value.shape(output.head) :: evT.value.shape(output.tail)
evH.value.shape(output.head) :: evT.shape(output.tail)
}

override def decodeShape(
output: HT :: TT,
shapes: Seq[Shape]
): (HS :: TS, Seq[Shape]) = {
val (headOut, headRemaining) = evH.value.decodeShape(output.head, shapes)
val (tailOut, tailRemaining) = evT.value.decodeShape(output.tail, headRemaining)
val (tailOut, tailRemaining) = evT.decodeShape(output.tail, headRemaining)
(headOut :: tailOut, tailRemaining)
}

Expand All @@ -490,7 +490,7 @@ object OutputToShape {
converter: OutputStructure.Converter
): HT :: TT = {
evH.value.map(value.head, shape.map(_.head), converter) ::
evT.value.map(value.tail, shape.map(_.tail), converter)
evT.map(value.tail, shape.map(_.tail), converter)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,25 +283,25 @@ object OutputToTensor {

implicit def fromHList[HT, HV, TT <: HList, TV <: HList](implicit
evH: Strict[OutputToTensor.Aux[HT, HV]],
evT: Strict[OutputToTensor.Aux[TT, TV]]
evT: OutputToTensor.Aux[TT, TV]
): OutputToTensor.Aux[HT :: TT, HV :: TV] = {
new OutputToTensor[HT :: TT] {
override type V = HV :: TV

override def tensorStructure: TensorStructure[HV :: TV] = {
TensorStructure.fromHList[HV, TV](evH.value.tensorStructure, evT.value.tensorStructure)
TensorStructure.fromHList[HV, TV](evH.value.tensorStructure, evT.tensorStructure)
}

override def size(output: HT :: TT): Int = {
evH.value.size(output.head) + evT.value.size(output.tail)
evH.value.size(output.head) + evT.size(output.tail)
}

override def decodeTensor(
output: HT :: TT,
tensors: Seq[Tensor[Any]]
): (HV :: TV, Seq[Tensor[Any]]) = {
val (headOut, headRemaining) = evH.value.decodeTensor(output.head, tensors)
val (tailOut, tailRemaining) = evT.value.decodeTensor(output.tail, headRemaining)
val (tailOut, tailRemaining) = evT.decodeTensor(output.tail, headRemaining)
(headOut :: tailOut, tailRemaining)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,23 +163,23 @@ object ShapeStructure {

implicit def fromHList[HS, TS <: HList](implicit
evH: Strict[ShapeStructure[HS]],
evT: Strict[ShapeStructure[TS]]
evT: ShapeStructure[TS]
): ShapeStructure[HS :: TS] = {
new ShapeStructure[HS :: TS] {
override def size(shape: HS :: TS): Int = {
evH.value.size(shape.head) + evT.value.size(shape.tail)
evH.value.size(shape.head) + evT.size(shape.tail)
}

override def shapes(shape: HS :: TS): Seq[Shape] = {
evH.value.shapes(shape.head) ++ evT.value.shapes(shape.tail)
evH.value.shapes(shape.head) ++ evT.shapes(shape.tail)
}

override def decodeShape(
shape: HS :: TS,
shapes: Seq[Shape]
): (HS :: TS, Seq[Shape]) = {
val (headOut, headRemaining) = evH.value.decodeShape(shape.head, shapes)
val (tailOut, tailRemaining) = evT.value.decodeShape(shape.tail, headRemaining)
val (tailOut, tailRemaining) = evT.decodeShape(shape.tail, headRemaining)
(headOut :: tailOut, tailRemaining)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ object TensorStructure {

implicit def fromHList[HT, TT <: HList](implicit
evH: Strict[TensorStructure[HT]],
evT: Strict[TensorStructure[TT]]
evT: TensorStructure[TT]
): TensorStructure[HT :: TT] = {
new TensorStructure[HT :: TT] {
override def tensors(tensor: HT :: TT): Seq[Tensor[Any]] = {
evH.value.tensors(tensor.head) ++ evT.value.tensors(tensor.tail)
evH.value.tensors(tensor.head) ++ evT.tensors(tensor.tail)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,13 @@ object TensorToDataType {

implicit def fromHList[HT, HD, TT <: HList, TD <: HList](implicit
evH: Strict[TensorToDataType.Aux[HT, HD]],
evT: Strict[TensorToDataType.Aux[TT, TD]]
evT: TensorToDataType.Aux[TT, TD]
): TensorToDataType.Aux[HT :: TT, HD :: TD] = {
new TensorToDataType[HT :: TT] {
override type D = HD :: TD

override def dataType(output: HT :: TT): HD :: TD = {
evH.value.dataType(output.head) :: evT.value.dataType(output.tail)
evH.value.dataType(output.head) :: evT.dataType(output.tail)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,17 +163,17 @@ object TensorToOutput {

implicit def fromHList[HT, HO, TT <: HList, TO <: HList](implicit
evH: Strict[TensorToOutput.Aux[HT, HO]],
evT: Strict[TensorToOutput.Aux[TT, TO]]
evT: TensorToOutput.Aux[TT, TO]
): TensorToOutput.Aux[HT :: TT, HO :: TO] = {
new TensorToOutput[HT :: TT] {
override type O = HO :: TO

override def tensorStructure: TensorStructure[HT :: TT] = {
TensorStructure.fromHList[HT, TT](evH.value.tensorStructure, evT.value.tensorStructure)
TensorStructure.fromHList[HT, TT](evH.value.tensorStructure, evT.tensorStructure)
}

override def output(tensor: HT :: TT): HO :: TO = {
evH.value.output(tensor.head) :: evT.value.output(tensor.tail)
evH.value.output(tensor.head) :: evT.output(tensor.tail)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,13 @@ object TensorToShape {

implicit def fromHList[HT, HS, TT <: HList, TS <: HList](implicit
evH: Strict[TensorToShape.Aux[HT, HS]],
evT: Strict[TensorToShape.Aux[TT, TS]]
evT: TensorToShape.Aux[TT, TS]
): TensorToShape.Aux[HT :: TT, HS :: TS] = {
new TensorToShape[HT :: TT] {
override type S = HS :: TS

override def shape(output: HT :: TT): HS :: TS = {
evH.value.shape(output.head) :: evT.value.shape(output.tail)
evH.value.shape(output.head) :: evT.shape(output.tail)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,13 @@ object Zero {

implicit def fromHList[HT, HS, TT <: HList, TS <: HList](implicit
evH: Strict[Zero.Aux[HT, HS]],
evT: Strict[Zero.Aux[TT, TS]]
evT: Zero.Aux[TT, TS]
): Zero.Aux[HT :: TT, HS :: TS] = {
new Zero[HT :: TT] {
override type S = HS :: TS

override def evOutputToShape: OutputToShape.Aux[HT :: TT, HS :: TS] = {
OutputToShape.fromHList[HT, HS, TT, TS](evH.value.evOutputToShape, evT.value.evOutputToShape)
OutputToShape.fromHList[HT, HS, TT, TS](evH.value.evOutputToShape, evT.evOutputToShape)
}

override def zero(
Expand All @@ -198,7 +198,7 @@ object Zero {
): HT :: TT = {
Op.nameScope(name) {
evH.value.zero(batchSize, shape.head) ::
evT.value.zero(batchSize, shape.tail)
evT.zero(batchSize, shape.tail)
}
}
}
Expand Down
Loading

0 comments on commit 7677354

Please sign in to comment.