Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements for on method & tests #2

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ open class HubCommunicationTask : DefaultTask() {
returns = Unit::class.asTypeName(),
body = {
if (!reified) {
addStatement("%L", "on(target = target, hasResult = true)")
addStatement("%L", "on(target = target, hasResult = resultType != Unit::class)")
.addCode(
format = "%L",
"""
Expand Down
1 change: 1 addition & 0 deletions signalrkore/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ kotlin {
val commonTest by getting {
dependencies {
implementation(kotlin("test"))
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.8.0")
}
}
val jvmMain by getting {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package eu.lepicekmichal.signalrkore

import io.ktor.client.HttpClient
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.serialization.json.Json
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds
Expand All @@ -12,6 +14,16 @@ class HttpHubConnectionBuilder(private val url: String) {
*/
var transportEnum: TransportEnum = TransportEnum.All

/**
* The [Transport] to be used by the [eu.lepicekmichal.signalrkore.HubConnection]
*/
internal var transport: Transport? = null

/**
* The [CoroutineDispatcher] to be used by the [eu.lepicekmichal.signalrkore.HubConnection]
*/
internal var dispatcher: CoroutineDispatcher = Dispatchers.IO

/**
* The [HttpClient] to be used by the [eu.lepicekmichal.signalrkore.HubConnection]
*/
Expand Down Expand Up @@ -72,8 +84,10 @@ class HttpHubConnectionBuilder(private val url: String) {
if (::protocol.isInitialized) protocol else JsonHubProtocol(logger),
handshakeResponseTimeout,
headers.toMap(),
transport,
transportEnum,
json,
logger,
dispatcher
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -281,4 +281,149 @@ abstract class HubCommunicationLink(private val json: Json) : HubCommunication()
.filter { it.target == target }
.onEach { logger.log(Logger.Severity.INFO, "Received invocation: $it", null) }
}


fun on2(target: String, callback: suspend () -> Unit) {
on2(
target = target,
paramTypes = emptyList(),
callback = {
callback()
},
)
}

@Suppress("UNCHECKED_CAST")
fun <P1 : Any> on2(target: String, paramType1: KClass<P1>, callback: suspend (P1) -> Unit) {
on2(
target = target,
paramTypes = listOf(paramType1),
callback = {
callback(it[0] as P1)
},
)
}

inline fun <reified P1 : Any> on2(target: String, noinline callback: suspend (P1) -> Unit) {
on2(
target = target,
paramType1 = P1::class,
callback = callback,
)
}

@Suppress("UNCHECKED_CAST")
fun <P1 : Any, P2 : Any> on2(target: String, paramType1: KClass<P1>, paramType2: KClass<P2>, callback: suspend (P1, P2) -> Unit) {
on2(
target = target,
paramTypes = listOf(paramType1, paramType2),
callback = {
callback(it[0] as P1, it[1] as P2)
},
)
}

inline fun <reified P1 : Any, reified P2 : Any> on2(target: String, noinline callback: suspend (P1, P2) -> Unit) {
on2(
target = target,
paramType1 = P1::class,
paramType2 = P2::class,
callback = callback,
)
}

fun <RESULT: Any> onWithResult2(target: String, resultType: KClass<RESULT>, callback: suspend () -> RESULT) {
onWithResult2(
target = target,
resultType = resultType,
paramTypes = emptyList(),
callback = {
callback()
},
)
}

inline fun <reified RESULT : Any> onWithResult2(target: String, noinline callback: suspend () -> RESULT) {
onWithResult2(
target = target,
resultType = RESULT::class,
callback = callback,
)
}

@Suppress("UNCHECKED_CAST")
fun <RESULT: Any, P1 : Any> onWithResult2(target: String, paramType1: KClass<P1>, resultType: KClass<RESULT>, callback: suspend (P1) -> RESULT) {
onWithResult2(
target = target,
resultType = resultType,
paramTypes = listOf(paramType1),
callback = {
callback(it[0] as P1)
},
)
}

inline fun <reified RESULT : Any, reified P1 : Any> onWithResult2(target: String, noinline callback: suspend (P1) -> RESULT) {
onWithResult2(
target = target,
paramType1 = P1::class,
resultType = RESULT::class,
callback = callback,
)
}

@Suppress("UNCHECKED_CAST")
fun <RESULT: Any, P1 : Any, P2 : Any> onWithResult2(target: String, paramType1: KClass<P1>, paramType2: KClass<P2>, resultType: KClass<RESULT>, callback: suspend (P1, P2) -> RESULT) {
onWithResult2(
target = target,
resultType = resultType,
paramTypes = listOf(paramType1, paramType2),
callback = {
callback(it[0] as P1, it[1] as P2)
},
)
}

inline fun <reified RESULT : Any, reified P1 : Any, reified P2 : Any> onWithResult2(target: String, noinline callback: suspend (P1, P2) -> RESULT) {
onWithResult2(
target = target,
paramType1 = P1::class,
paramType2 = P2::class,
resultType = RESULT::class,
callback = callback,
)
}

fun <RESULT: Any> onWithResult2(target: String, resultType: KClass<RESULT>, paramTypes: List<KClass<*>>, callback: suspend (List<Any>) -> RESULT) {
if (!resultProviderRegistry.add(target)) {
throw IllegalStateException("There can be only one function for returning result on blocking invocation (method: $target)")
}

return receivedInvocations
.onCompletion { resultProviderRegistry.remove(target) }
.filter { it.target == target }
.onEach { logger.log(Logger.Severity.INFO, "Received invocation: $it", null) }
.handleIncomingInvocation(
resultType = resultType,
callback = {
callback(
it.arguments.mapIndexed { index, arg -> arg.fromJson(paramTypes[index]) }
)
},
)
}

fun on2(target: String, paramTypes: List<KClass<*>>, callback: suspend (List<Any>) -> Unit) {
return receivedInvocations
.filter { it.target == target }
.onEach { logger.log(Logger.Severity.INFO, "Received invocation: $it", null) }
.handleIncomingInvocation(
resultType = Unit::class,
callback = {
callback(
it.arguments.mapIndexed { index, arg -> arg.fromJson(paramTypes[index]) }
)
},
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import io.ktor.serialization.kotlinx.json.*
import io.ktor.util.*
import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
Expand Down Expand Up @@ -55,10 +56,11 @@ class HubConnection private constructor(
private val automaticReconnect: AutomaticReconnect,
override val logger: Logger,
json: Json,
dispatcher: CoroutineDispatcher,
) : HubCommunicationLink(json) {

private val job = SupervisorJob()
override val scope = CoroutineScope(job + Dispatchers.IO)
override val scope = CoroutineScope(job + dispatcher)

private val pingReset = MutableSharedFlow<Unit>(extraBufferCapacity = 1)
private val pingTicker = pingReset
Expand Down Expand Up @@ -95,9 +97,11 @@ class HubConnection private constructor(
protocol: HubProtocol,
handshakeResponseTimeout: Duration,
headers: Map<String, String>,
transport: Transport?,
transportEnum: TransportEnum,
json: Json,
logger: Logger,
dispatcher: CoroutineDispatcher,
) : this(
baseUrl = url.takeIf { it.isNotBlank() } ?: throw IllegalArgumentException("A valid url is required."),
protocol = protocol,
Expand All @@ -113,7 +117,12 @@ class HubConnection private constructor(
automaticReconnect = automaticReconnect,
json = json,
logger = logger,
)
dispatcher = dispatcher,
) {
if (transport != null) {
this.transport = transport
}
}

suspend fun start(reconnectionAttempt: Boolean = false) {
if (connectionState.value != HubConnectionState.DISCONNECTED && connectionState.value != HubConnectionState.RECONNECTING) return
Expand Down Expand Up @@ -141,10 +150,12 @@ class HubConnection private constructor(
Negotiation(TransportEnum.WebSockets, baseUrl)
}

transport = when (negotiationTransport) {
TransportEnum.LongPolling -> LongPollingTransport(headers, httpClient)
TransportEnum.ServerSentEvents -> ServerSentEventsTransport(headers, httpClient)
else -> WebSocketTransport(headers, httpClient)
if (!::transport.isInitialized) {
transport = when (negotiationTransport) {
TransportEnum.LongPolling -> LongPollingTransport(headers, httpClient)
TransportEnum.ServerSentEvents -> ServerSentEventsTransport(headers, httpClient)
else -> WebSocketTransport(headers, httpClient)
}
}

try {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package eu.lepicekmichal.signalrkore

import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.FlowPreview
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.timeout
import kotlinx.coroutines.withContext
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds

class Completable {
private val stateFlow = MutableStateFlow(false)

@OptIn(FlowPreview::class)
suspend fun waitForCompletion(timeout: Duration = 5.seconds) = withContext(Dispatchers.Default) {
stateFlow.filter { it }.timeout(timeout).first()
}

fun reset() {
stateFlow.value = false
}

fun complete() {
stateFlow.value = true
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package eu.lepicekmichal.signalrkore

import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.test.runTest
import kotlin.test.BeforeTest
import kotlin.test.AfterTest

abstract class HubTest {
protected lateinit var hubConnection: HubConnection
protected lateinit var transport: MockTransport
protected lateinit var logger: TestLogger

@BeforeTest
fun setup() {
transport = MockTransport()
logger = TestLogger()
hubConnection = createHubConnection(transport) {
logger = [email protected]
// By using Unconfined dispatcher, we assure that on method will start collecting before the connection is started
dispatcher = Dispatchers.Unconfined
}
}

@AfterTest
fun cleanup() = runTest {
hubConnection.stop()
}

private fun createHubConnection(customTransport: Transport, block: HttpHubConnectionBuilder.() -> Unit = {}): HubConnection {
val builder = HttpHubConnectionBuilder("http://example.com").apply {
transport = customTransport
skipNegotiate = true
transportEnum = TransportEnum.WebSockets
}
builder.apply(block)

return builder.build()
}
}
Loading