Skip to content

Commit

Permalink
Fixed PR Issues and extended Commitment protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
adityanathan committed Sep 23, 2024
1 parent dbb7419 commit 914471a
Show file tree
Hide file tree
Showing 12 changed files with 94 additions and 40 deletions.
21 changes: 21 additions & 0 deletions Commitment3.circuit
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
host alice
host bob
host chuck
host david

circuit fun <> move@Local(host = alice)(a: int[]) -> b: int[] {
return a
}

circuit fun <> move2@Replication(hosts = {bob, david})(a: int[]) -> b: int[] {
return a
}

fun <> main() -> {
val a@Local(host = alice) = alice.input<int[]>()
val c@Commitment(sender = alice, receivers = {bob, chuck}) = move<>(a)
val d@Replication(hosts = {bob, david}) = move2<>(c)
val = bob.output<int[]>(d)
val = david.output<int[]>(d)
return
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package io.github.aplcornell.viaduct.backends.cleartext

import com.squareup.kotlinpoet.BYTE_ARRAY
import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.MemberName
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
import com.squareup.kotlinpoet.asClassName
import com.squareup.kotlinpoet.asTypeName
import com.squareup.kotlinpoet.TypeName
import io.github.aplcornell.viaduct.circuitcodegeneration.AbstractCodeGenerator
import io.github.aplcornell.viaduct.circuitcodegeneration.Argument
import io.github.aplcornell.viaduct.circuitcodegeneration.CodeGeneratorContext
Expand All @@ -22,6 +24,8 @@ import io.github.aplcornell.viaduct.syntax.UnaryOperator
import io.github.aplcornell.viaduct.syntax.circuit.OperatorNode
import io.github.aplcornell.viaduct.syntax.operators.Maximum
import io.github.aplcornell.viaduct.syntax.operators.Minimum
import io.github.aplcornell.viaduct.syntax.types.ByteVecType
import io.github.aplcornell.viaduct.syntax.values.HostSetValue
import io.github.aplcornell.viaduct.backends.commitment.Commitment as CommitmentProtocol

class CleartextCircuitCodeGenerator(context: CodeGeneratorContext) : AbstractCodeGenerator(context) {
Expand Down Expand Up @@ -121,6 +125,28 @@ class CleartextCircuitCodeGenerator(context: CodeGeneratorContext) : AbstractCod
}
}

private fun checkPeerValues(
peers: HostSetValue,
value: CodeBlock,
valueType: TypeName,
builder: CodeBlock.Builder,
) {
val receivingPeers = peers.filter { it != context.host }
if (receivingPeers.isNotEmpty()) {
for (host in receivingPeers) builder.addStatement("%L", context.send(value, host))
builder.addStatement(
"%L",
receiveExpected(
value,
context.host,
valueType,
receivingPeers,
context,
),
)
}
}

private fun createCommitment(
source: Protocol,
target: Protocol,
Expand All @@ -131,23 +157,21 @@ class CleartextCircuitCodeGenerator(context: CodeGeneratorContext) : AbstractCod
if (source !is Local) {
throw UnsupportedCommunicationException(source, target, argument.sourceLocation)
}
require(source.hosts.size == 1 && source.host in source.hosts)
require(target is CommitmentProtocol)
if (target.cleartextHost != source.host || target.cleartextHost in target.hashHosts) {
if (target.cleartextHost != source.host) {
throw UnsupportedCommunicationException(source, target, argument.sourceLocation)
}

val argType = kotlinType(argument.type.shape, typeTranslator(argument.type.elementType.value))
val sendingHost = target.cleartextHost
val receivingHosts = target.hashHosts
return when (context.host) {
sendingHost -> {
val tempName1 = context.newTemporary("CommitTemp")
val tempName2 = context.newTemporary("CommitTemp")
val tempName1 = context.newTemporary("committed")
val tempName2 = context.newTemporary("commitment")
builder.addStatement(
"val %N = %T(%L)",
tempName1,
(Committed::class).asTypeName().parameterizedBy(argType),
(Committed::class).asTypeName(),
argument.value,
)
builder.addStatement(
Expand All @@ -163,12 +187,14 @@ class CleartextCircuitCodeGenerator(context: CodeGeneratorContext) : AbstractCod
}

in receivingHosts -> {
val tempName3 = context.newTemporary("CommitTemp")
val argType = kotlinType(argument.type.shape, typeTranslator(argument.type.elementType.value))
val tempName3 = context.newTemporary("commitment")
builder.addStatement(
"val %N = %L",
tempName3,
context.receive((Commitment::class).asTypeName().parameterizedBy(argType), source.host),
)
checkPeerValues(HostSetValue(receivingHosts), CodeBlock.of("%L.hash", tempName3), BYTE_ARRAY, builder)
CodeBlock.of("%N", tempName3)
}

Expand All @@ -187,57 +213,47 @@ class CleartextCircuitCodeGenerator(context: CodeGeneratorContext) : AbstractCod
throw UnsupportedCommunicationException(source, target, argument.sourceLocation)
}
require(context.host in source.hosts + target.hosts)
if (source.hashHosts != target.hosts || source.cleartextHost in source.hashHosts) {
throw UnsupportedCommunicationException(source, target, argument.sourceLocation)
}

val argType = kotlinType(argument.type.shape, typeTranslator(argument.type.elementType.value))
val sendingHost = source.cleartextHost
val commitmentReceivingHosts = target.hosts.filter { it in source.hashHosts }
val cleartextReceivingHosts = target.hosts - commitmentReceivingHosts
val omittedHosts = source.hashHosts - target.hosts
val receivingHosts = target.hosts
return when (context.host) {
sendingHost -> {
receivingHosts.forEach {
commitmentReceivingHosts.forEach {
builder.addStatement("%L", context.send(argument.value, it))
}
cleartextReceivingHosts.forEach {
builder.addStatement("%L", context.send(CodeBlock.of("%L.value", argument.value), it))
}
CodeBlock.of("%L.value", argument.value)
}
in receivingHosts -> {
val tempName1 = context.newTemporary("CommitTemp")
in commitmentReceivingHosts -> {
val tempName1 = context.newTemporary("commitTemp")
builder.addStatement(
"val %N = %L",
"val %N = %L.%N(%L)",
tempName1,
argument.value,
"open",
context.receive((Committed::class).asTypeName().parameterizedBy(argType), source.cleartextHost),
)
val tempName2 = context.newTemporary("CommitTemp")
checkPeerValues(receivingHosts, CodeBlock.of(tempName1), argType, builder)
CodeBlock.of("%N", tempName1)
}
in cleartextReceivingHosts -> {
val tempName1 = context.newTemporary("cleartextTemp")
builder.addStatement(
"val %N = %L",
tempName2,
argument.value,
)
val tempName3 = context.newTemporary("CommitTemp")
builder.addStatement(
"val %N = %N.%N(%N)",
tempName3,
tempName2,
"open",
tempName1,
context.receive(argType, sendingHost),
)

val peers = receivingHosts.filter { it != context.host }
if (peers.isNotEmpty()) {
for (host in peers) builder.addStatement("%L", context.send(CodeBlock.of(tempName3), host))
builder.addStatement(
"%L",
receiveExpected(
CodeBlock.of(tempName3),
context.host,
argType,
peers,
context,
),
)
}
CodeBlock.of("%N", tempName3)
checkPeerValues(receivingHosts, CodeBlock.of(tempName1), argType, builder)
CodeBlock.of("%N", tempName1)
}
in omittedHosts -> {
CodeBlock.of("")
}
else -> throw IllegalStateException()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ fun <> main() -> {
val a@Local(host = alice) = alice.input<int[]>()
val c@Commitment(sender = alice, receivers = {bob, chuck}) = move<>(a)
val d@Replication(hosts = {bob, david}) = move2<>(c)
val = bob.output<int[]>(d)
val = david.output<int[]>(d)
return
}
1 change: 1 addition & 0 deletions examples/inputs/circuit/cleartext/Commitment3-alice.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
5
Empty file.
Empty file.
Empty file.
Empty file.
1 change: 1 addition & 0 deletions examples/outputs/circuit/cleartext/Commitment3-bob.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
5
1 change: 1 addition & 0 deletions examples/outputs/circuit/cleartext/Commitment3-chuck.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

1 change: 1 addition & 0 deletions examples/outputs/circuit/cleartext/Commitment3-david.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
5
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,17 @@ class EquivocationException private constructor(
throw EquivocationException(expectedValue, expectedValueProvider, actualValue, actualValueProvider)
}
}

/** Throws [EquivocationException] if [expectedValue] does not match [actualValue]. */
fun assertEquals(
expectedValue: ByteArray,
expectedValueProvider: Host,
actualValue: ByteArray,
actualValueProvider: Host,
) {
if (!expectedValue.contentEquals(actualValue)) {
throw EquivocationException(expectedValue, expectedValueProvider, actualValue, actualValueProvider)
}
}
}
}

0 comments on commit 914471a

Please sign in to comment.