fixed bug that prevented kilo client to restart if the channel was closed while channel send operation was suspended (featurebug of channels, one of)

This commit is contained in:
Sergey Chernov 2025-03-16 11:52:18 +03:00
parent 146878629e
commit 7e1f7ec4aa
6 changed files with 118 additions and 70 deletions

View File

@ -91,9 +91,11 @@ class KiloClient<S>(
debug { "get device and session" } debug { "get device and session" }
val client = KiloClientConnection(localInterface, kc, secretKey) val client = KiloClientConnection(localInterface, kc, secretKey)
deferredClient.complete(client) deferredClient.complete(client)
client.run { debug { "starting client run"}
val r = runCatching { client.run {
_state.value = it _state.value = it
} } }
debug { "----------- client run finished: $r" }
resetDeferredClient() resetDeferredClient()
debug { "client run finished" } debug { "client run finished" }
} catch (_: RemoteInterface.ClosedException) { } catch (_: RemoteInterface.ClosedException) {
@ -109,7 +111,7 @@ class KiloClient<S>(
_state.value = false _state.value = false
resetDeferredClient() resetDeferredClient()
// reconnection timeout // reconnection timeout
delay(100) delay(700)
} }
} }

View File

@ -49,7 +49,6 @@ class KiloClientConnection<S>(
try { try {
// in parallel: keys and connection // in parallel: keys and connection
val deferredKeyPair = async { SafeKeyExchange() } val deferredKeyPair = async { SafeKeyExchange() }
debug { "opening device" }
debug { "got a transport device $device" } debug { "got a transport device $device" }
@ -62,10 +61,11 @@ class KiloClientConnection<S>(
debug { "transport started" } debug { "transport started" }
val pair = deferredKeyPair.await() val pair = deferredKeyPair.await()
debug { "keypair ready" } debug { "keypair ready (1)" }
val serverHe = transport.call(L0Request, Handshake(1u, pair.publicKey)) val serverHe = transport.call(L0Request, Handshake(1u, pair.publicKey))
debug { "got server HE (2)" }
val sk = pair.clientSessionKey(serverHe.publicKey) val sk = pair.clientSessionKey(serverHe.publicKey)
var params = KiloParams(false, transport, sk, session, null, this@KiloClientConnection) var params = KiloParams(false, transport, sk, session, null, this@KiloClientConnection)
@ -97,8 +97,7 @@ class KiloClientConnection<S>(
} catch (x: CancellationException) { } catch (x: CancellationException) {
info { "client is cancelled" } info { "client is cancelled" }
} catch (x: RemoteInterface.ClosedException) { } catch (x: RemoteInterface.ClosedException) {
x.printStackTrace() debug { "connection closed/refused by remote" }
info { "connection closed by remote" }
} finally { } finally {
onConnectedStateChanged?.invoke(false) onConnectedStateChanged?.invoke(false)
job?.cancel() job?.cancel()

View File

@ -134,7 +134,13 @@ class Transport<S>(
} }
// now we have mutex freed so we can call: // now we have mutex freed so we can call:
val r = runCatching { device.output.send(pack(b)) } val r = runCatching {
do {
val cr = device.output.trySend(pack(b))
if( cr.isClosed ) throw ClosedSendChannelException("can't send block: channel is closed")
delay(100)
} while(!cr.isSuccess)
}
if (!r.isSuccess) { if (!r.isSuccess) {
r.exceptionOrNull()?.let { r.exceptionOrNull()?.let {
exception { "failed to send output block" to it } exception { "failed to send output block" to it }
@ -271,7 +277,7 @@ class Transport<S>(
} }
debug { "no more active: $isActive / ${calls.size}" } debug { "no more active: $isActive / ${calls.size}" }
} }
info { "exiting transport loop" } debug { "exiting transport loop" }
} }
private suspend fun send(block: Block) { private suspend fun send(block: Block) {

View File

@ -20,12 +20,10 @@ import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ClosedReceiveChannelException import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.channels.ClosedSendChannelException import kotlinx.coroutines.channels.ClosedSendChannelException
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.io.IOException
import net.sergeych.crypto2.SigningKey import net.sergeych.crypto2.SigningKey
import net.sergeych.kiloparsec.* import net.sergeych.kiloparsec.*
import net.sergeych.mp_logger.LogTag import net.sergeych.mp_logger.*
import net.sergeych.mp_logger.exception
import net.sergeych.mp_logger.info
import net.sergeych.mp_logger.warning
import net.sergeych.mp_tools.globalLaunch import net.sergeych.mp_tools.globalLaunch
import net.sergeych.tools.AtomicCounter import net.sergeych.tools.AtomicCounter
@ -67,8 +65,11 @@ fun websocketTransportDevice(
val input = Channel<UByteArray>() val input = Channel<UByteArray>()
val output = Channel<UByteArray>() val output = Channel<UByteArray>()
val closeHandle = CompletableDeferred<Boolean>() val closeHandle = CompletableDeferred<Boolean>()
val readyHandle = CompletableDeferred<Unit>()
globalLaunch { globalLaunch {
val log = LogTag("KC:${counter.incrementAndGet()}") val log = LogTag("KC:${counter.incrementAndGet()}")
try {
client.webSocket({ client.webSocket({
url.protocol = u.protocol url.protocol = u.protocol
url.host = u.host url.host = u.host
@ -80,6 +81,7 @@ fun websocketTransportDevice(
log.info { "connected to the server" } log.info { "connected to the server" }
// println("SENDING!!!") // println("SENDING!!!")
// send("Helluva") // send("Helluva")
readyHandle.complete(Unit)
launch { launch {
try { try {
for (block in output) { for (block in output) {
@ -120,17 +122,28 @@ fun websocketTransportDevice(
log.warning { "Client is closing with error" } log.warning { "Client is closing with error" }
throw RemoteInterface.ClosedException() throw RemoteInterface.ClosedException()
} }
output.close() runCatching { output.close() }
input.close() runCatching { input.close() }
runCatching { close() }
}
}
catch(x: IOException) {
if( "refused" in x.toString()) log.debug { "connection refused" }
else log.warning { "unexpected IO error $x" }
runCatching { output.close() }
runCatching { input.close() }
} }
log.info { "closing connection" } log.info { "closing connection" }
} }
val device = ProxyDevice(input, output) { // Wait for connection be established or failed
val device = ProxyDevice(input, output, doClose = {
// we need to explicitly close the coroutine job, or it can hang for a long time // we need to explicitly close the coroutine job, or it can hang for a long time
// leaking resources. // leaking resources.
runCatching { output.close() }
runCatching { input.close() }
closeHandle.complete(true) closeHandle.complete(true)
// job.cancel() // job.cancel()
} })
return device return device
} }

View File

@ -15,10 +15,13 @@ import io.ktor.server.engine.*
import io.ktor.server.netty.* import io.ktor.server.netty.*
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest import kotlinx.coroutines.test.runTest
import net.sergeych.crypto2.SigningSecretKey
import net.sergeych.crypto2.initCrypto import net.sergeych.crypto2.initCrypto
import net.sergeych.kiloparsec.adapter.setupWebsocketServer import net.sergeych.kiloparsec.adapter.setupWebsocketServer
import net.sergeych.kiloparsec.adapter.websocketClient import net.sergeych.kiloparsec.adapter.websocketClient
import net.sergeych.kiloparsec.adapter.websocketTransportDevice
import net.sergeych.mp_logger.Log import net.sergeych.mp_logger.Log
import java.net.InetAddress import java.net.InetAddress
import kotlin.random.Random import kotlin.random.Random
@ -50,6 +53,7 @@ class ClientTest {
Log.connectConsole(Log.Level.DEBUG) Log.connectConsole(Log.Level.DEBUG)
data class Session(var foo: String = "not set") data class Session(var foo: String = "not set")
var closeCounter = 0 var closeCounter = 0
val serverInterface = KiloInterface<Session>().apply { val serverInterface = KiloInterface<Session>().apply {
var connectedCalled = false var connectedCalled = false
@ -73,7 +77,9 @@ class ClientTest {
client.connectedStateFlow.collect { client.connectedStateFlow.collect {
println("got: $closeCounter/$it") println("got: $closeCounter/$it")
states += it states += it
if( !it) { closeCounter++ } if (!it) {
closeCounter++
}
} }
} }
assertEquals(true, client.call(cmdCheckConnected)) assertEquals(true, client.call(cmdCheckConnected))
@ -104,4 +110,26 @@ class ClientTest {
// println("stopped server") // println("stopped server")
// println("closed client") // println("closed client")
} }
@Test
fun webSocketWaitForConnectTest() = runBlocking {
initCrypto()
// fun Application.
Log.connectConsole(Log.Level.DEBUG)
val clientInterface = KiloInterface<Unit>().apply {}
val port = Random.nextInt(8080, 9090)
var clientConnectCalls = 0
// It should repeatedly reconnect, and we will count:
KiloClient(clientInterface, SigningSecretKey.new()) {
clientConnectCalls++
KiloConnectionData(websocketTransportDevice("ws://localhost:$port/kp"), Unit)
}
delay(1200)
// and check:
// println("connection attemtps: $clientConnectCalls")
assertTrue { clientConnectCalls > 1 }
}
} }