fixed bug that prevented kilo client to restart if the channel was closed while channel send operation was suspended (featurebug of channels, one of)
This commit is contained in:
		
							parent
							
								
									146878629e
								
							
						
					
					
						commit
						7e1f7ec4aa
					
				@ -91,9 +91,11 @@ class KiloClient<S>(
 | 
				
			|||||||
                    debug { "get device and session" }
 | 
					                    debug { "get device and session" }
 | 
				
			||||||
                    val client = KiloClientConnection(localInterface, kc, secretKey)
 | 
					                    val client = KiloClientConnection(localInterface, kc, secretKey)
 | 
				
			||||||
                    deferredClient.complete(client)
 | 
					                    deferredClient.complete(client)
 | 
				
			||||||
                    client.run {
 | 
					                    debug { "starting client run"}
 | 
				
			||||||
 | 
					                    val r = runCatching {  client.run {
 | 
				
			||||||
                        _state.value = it
 | 
					                        _state.value = it
 | 
				
			||||||
                    }
 | 
					                    } }
 | 
				
			||||||
 | 
					                    debug { "----------- client run finished: $r" }
 | 
				
			||||||
                    resetDeferredClient()
 | 
					                    resetDeferredClient()
 | 
				
			||||||
                    debug { "client run finished" }
 | 
					                    debug { "client run finished" }
 | 
				
			||||||
                } catch (_: RemoteInterface.ClosedException) {
 | 
					                } catch (_: RemoteInterface.ClosedException) {
 | 
				
			||||||
@ -109,7 +111,7 @@ class KiloClient<S>(
 | 
				
			|||||||
                _state.value = false
 | 
					                _state.value = false
 | 
				
			||||||
                resetDeferredClient()
 | 
					                resetDeferredClient()
 | 
				
			||||||
                // reconnection timeout
 | 
					                // reconnection timeout
 | 
				
			||||||
                delay(100)
 | 
					                delay(700)
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -49,7 +49,6 @@ class KiloClientConnection<S>(
 | 
				
			|||||||
            try {
 | 
					            try {
 | 
				
			||||||
                // in parallel: keys and connection
 | 
					                // in parallel: keys and connection
 | 
				
			||||||
                val deferredKeyPair = async { SafeKeyExchange() }
 | 
					                val deferredKeyPair = async { SafeKeyExchange() }
 | 
				
			||||||
                debug { "opening device" }
 | 
					 | 
				
			||||||
                debug { "got a transport device $device" }
 | 
					                debug { "got a transport device $device" }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -62,10 +61,11 @@ class KiloClientConnection<S>(
 | 
				
			|||||||
                debug { "transport started" }
 | 
					                debug { "transport started" }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                val pair = deferredKeyPair.await()
 | 
					                val pair = deferredKeyPair.await()
 | 
				
			||||||
                debug { "keypair ready" }
 | 
					                debug { "keypair ready (1)" }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                val serverHe = transport.call(L0Request, Handshake(1u, pair.publicKey))
 | 
					                val serverHe = transport.call(L0Request, Handshake(1u, pair.publicKey))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                debug { "got server HE (2)" }
 | 
				
			||||||
                val sk = pair.clientSessionKey(serverHe.publicKey)
 | 
					                val sk = pair.clientSessionKey(serverHe.publicKey)
 | 
				
			||||||
                var params = KiloParams(false, transport, sk, session, null, this@KiloClientConnection)
 | 
					                var params = KiloParams(false, transport, sk, session, null, this@KiloClientConnection)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -97,8 +97,7 @@ class KiloClientConnection<S>(
 | 
				
			|||||||
            } catch (x: CancellationException) {
 | 
					            } catch (x: CancellationException) {
 | 
				
			||||||
                info { "client is cancelled" }
 | 
					                info { "client is cancelled" }
 | 
				
			||||||
            } catch (x: RemoteInterface.ClosedException) {
 | 
					            } catch (x: RemoteInterface.ClosedException) {
 | 
				
			||||||
                x.printStackTrace()
 | 
					                debug { "connection closed/refused by remote" }
 | 
				
			||||||
                info { "connection closed by remote" }
 | 
					 | 
				
			||||||
            } finally {
 | 
					            } finally {
 | 
				
			||||||
                onConnectedStateChanged?.invoke(false)
 | 
					                onConnectedStateChanged?.invoke(false)
 | 
				
			||||||
                job?.cancel()
 | 
					                job?.cancel()
 | 
				
			||||||
 | 
				
			|||||||
@ -134,7 +134,13 @@ class Transport<S>(
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // now we have mutex freed so we can call:
 | 
					        // now we have mutex freed so we can call:
 | 
				
			||||||
        val r = runCatching { device.output.send(pack(b)) }
 | 
					        val r = runCatching {
 | 
				
			||||||
 | 
					            do {
 | 
				
			||||||
 | 
					                val cr = device.output.trySend(pack(b))
 | 
				
			||||||
 | 
					                if( cr.isClosed ) throw ClosedSendChannelException("can't send block: channel is closed")
 | 
				
			||||||
 | 
					                delay(100)
 | 
				
			||||||
 | 
					            } while(!cr.isSuccess)
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
        if (!r.isSuccess) {
 | 
					        if (!r.isSuccess) {
 | 
				
			||||||
            r.exceptionOrNull()?.let {
 | 
					            r.exceptionOrNull()?.let {
 | 
				
			||||||
                exception { "failed to send output block" to it }
 | 
					                exception { "failed to send output block" to it }
 | 
				
			||||||
@ -271,7 +277,7 @@ class Transport<S>(
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
            debug { "no more active: $isActive / ${calls.size}" }
 | 
					            debug { "no more active: $isActive / ${calls.size}" }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        info { "exiting transport loop" }
 | 
					        debug { "exiting transport loop" }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private suspend fun send(block: Block) {
 | 
					    private suspend fun send(block: Block) {
 | 
				
			||||||
 | 
				
			|||||||
@ -20,12 +20,10 @@ import kotlinx.coroutines.channels.Channel
 | 
				
			|||||||
import kotlinx.coroutines.channels.ClosedReceiveChannelException
 | 
					import kotlinx.coroutines.channels.ClosedReceiveChannelException
 | 
				
			||||||
import kotlinx.coroutines.channels.ClosedSendChannelException
 | 
					import kotlinx.coroutines.channels.ClosedSendChannelException
 | 
				
			||||||
import kotlinx.coroutines.launch
 | 
					import kotlinx.coroutines.launch
 | 
				
			||||||
 | 
					import kotlinx.io.IOException
 | 
				
			||||||
import net.sergeych.crypto2.SigningKey
 | 
					import net.sergeych.crypto2.SigningKey
 | 
				
			||||||
import net.sergeych.kiloparsec.*
 | 
					import net.sergeych.kiloparsec.*
 | 
				
			||||||
import net.sergeych.mp_logger.LogTag
 | 
					import net.sergeych.mp_logger.*
 | 
				
			||||||
import net.sergeych.mp_logger.exception
 | 
					 | 
				
			||||||
import net.sergeych.mp_logger.info
 | 
					 | 
				
			||||||
import net.sergeych.mp_logger.warning
 | 
					 | 
				
			||||||
import net.sergeych.mp_tools.globalLaunch
 | 
					import net.sergeych.mp_tools.globalLaunch
 | 
				
			||||||
import net.sergeych.tools.AtomicCounter
 | 
					import net.sergeych.tools.AtomicCounter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -67,8 +65,11 @@ fun websocketTransportDevice(
 | 
				
			|||||||
    val input = Channel<UByteArray>()
 | 
					    val input = Channel<UByteArray>()
 | 
				
			||||||
    val output = Channel<UByteArray>()
 | 
					    val output = Channel<UByteArray>()
 | 
				
			||||||
    val closeHandle = CompletableDeferred<Boolean>()
 | 
					    val closeHandle = CompletableDeferred<Boolean>()
 | 
				
			||||||
 | 
					    val readyHandle = CompletableDeferred<Unit>()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    globalLaunch {
 | 
					    globalLaunch {
 | 
				
			||||||
        val log = LogTag("KC:${counter.incrementAndGet()}")
 | 
					        val log = LogTag("KC:${counter.incrementAndGet()}")
 | 
				
			||||||
 | 
					        try {
 | 
				
			||||||
            client.webSocket({
 | 
					            client.webSocket({
 | 
				
			||||||
                url.protocol = u.protocol
 | 
					                url.protocol = u.protocol
 | 
				
			||||||
                url.host = u.host
 | 
					                url.host = u.host
 | 
				
			||||||
@ -80,6 +81,7 @@ fun websocketTransportDevice(
 | 
				
			|||||||
                log.info { "connected to the server" }
 | 
					                log.info { "connected to the server" }
 | 
				
			||||||
//                println("SENDING!!!")
 | 
					//                println("SENDING!!!")
 | 
				
			||||||
//                send("Helluva")
 | 
					//                send("Helluva")
 | 
				
			||||||
 | 
					                readyHandle.complete(Unit)
 | 
				
			||||||
                launch {
 | 
					                launch {
 | 
				
			||||||
                    try {
 | 
					                    try {
 | 
				
			||||||
                        for (block in output) {
 | 
					                        for (block in output) {
 | 
				
			||||||
@ -120,17 +122,28 @@ fun websocketTransportDevice(
 | 
				
			|||||||
                    log.warning { "Client is closing with error" }
 | 
					                    log.warning { "Client is closing with error" }
 | 
				
			||||||
                    throw RemoteInterface.ClosedException()
 | 
					                    throw RemoteInterface.ClosedException()
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            output.close()
 | 
					                runCatching { output.close() }
 | 
				
			||||||
            input.close()
 | 
					                runCatching { input.close() }
 | 
				
			||||||
 | 
					                runCatching { close() }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        catch(x: IOException) {
 | 
				
			||||||
 | 
					            if( "refused" in x.toString()) log.debug { "connection refused" }
 | 
				
			||||||
 | 
					            else log.warning { "unexpected IO error $x" }
 | 
				
			||||||
 | 
					            runCatching { output.close() }
 | 
				
			||||||
 | 
					            runCatching { input.close() }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        log.info { "closing connection" }
 | 
					        log.info { "closing connection" }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    val device = ProxyDevice(input, output) {
 | 
					    // Wait for connection be established or failed
 | 
				
			||||||
 | 
					    val device = ProxyDevice(input, output, doClose = {
 | 
				
			||||||
        // we need to explicitly close the coroutine job, or it can hang for a long time
 | 
					        // we need to explicitly close the coroutine job, or it can hang for a long time
 | 
				
			||||||
        // leaking resources.
 | 
					        // leaking resources.
 | 
				
			||||||
 | 
					        runCatching { output.close() }
 | 
				
			||||||
 | 
					        runCatching { input.close() }
 | 
				
			||||||
        closeHandle.complete(true)
 | 
					        closeHandle.complete(true)
 | 
				
			||||||
//            job.cancel()
 | 
					//            job.cancel()
 | 
				
			||||||
    }
 | 
					    })
 | 
				
			||||||
    return device
 | 
					    return device
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -15,10 +15,13 @@ import io.ktor.server.engine.*
 | 
				
			|||||||
import io.ktor.server.netty.*
 | 
					import io.ktor.server.netty.*
 | 
				
			||||||
import kotlinx.coroutines.delay
 | 
					import kotlinx.coroutines.delay
 | 
				
			||||||
import kotlinx.coroutines.launch
 | 
					import kotlinx.coroutines.launch
 | 
				
			||||||
 | 
					import kotlinx.coroutines.runBlocking
 | 
				
			||||||
import kotlinx.coroutines.test.runTest
 | 
					import kotlinx.coroutines.test.runTest
 | 
				
			||||||
 | 
					import net.sergeych.crypto2.SigningSecretKey
 | 
				
			||||||
import net.sergeych.crypto2.initCrypto
 | 
					import net.sergeych.crypto2.initCrypto
 | 
				
			||||||
import net.sergeych.kiloparsec.adapter.setupWebsocketServer
 | 
					import net.sergeych.kiloparsec.adapter.setupWebsocketServer
 | 
				
			||||||
import net.sergeych.kiloparsec.adapter.websocketClient
 | 
					import net.sergeych.kiloparsec.adapter.websocketClient
 | 
				
			||||||
 | 
					import net.sergeych.kiloparsec.adapter.websocketTransportDevice
 | 
				
			||||||
import net.sergeych.mp_logger.Log
 | 
					import net.sergeych.mp_logger.Log
 | 
				
			||||||
import java.net.InetAddress
 | 
					import java.net.InetAddress
 | 
				
			||||||
import kotlin.random.Random
 | 
					import kotlin.random.Random
 | 
				
			||||||
@ -50,6 +53,7 @@ class ClientTest {
 | 
				
			|||||||
        Log.connectConsole(Log.Level.DEBUG)
 | 
					        Log.connectConsole(Log.Level.DEBUG)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        data class Session(var foo: String = "not set")
 | 
					        data class Session(var foo: String = "not set")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        var closeCounter = 0
 | 
					        var closeCounter = 0
 | 
				
			||||||
        val serverInterface = KiloInterface<Session>().apply {
 | 
					        val serverInterface = KiloInterface<Session>().apply {
 | 
				
			||||||
            var connectedCalled = false
 | 
					            var connectedCalled = false
 | 
				
			||||||
@ -73,7 +77,9 @@ class ClientTest {
 | 
				
			|||||||
            client.connectedStateFlow.collect {
 | 
					            client.connectedStateFlow.collect {
 | 
				
			||||||
                println("got: $closeCounter/$it")
 | 
					                println("got: $closeCounter/$it")
 | 
				
			||||||
                states += it
 | 
					                states += it
 | 
				
			||||||
                if( !it) { closeCounter++ }
 | 
					                if (!it) {
 | 
				
			||||||
 | 
					                    closeCounter++
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        assertEquals(true, client.call(cmdCheckConnected))
 | 
					        assertEquals(true, client.call(cmdCheckConnected))
 | 
				
			||||||
@ -104,4 +110,26 @@ class ClientTest {
 | 
				
			|||||||
//        println("stopped server")
 | 
					//        println("stopped server")
 | 
				
			||||||
//        println("closed client")
 | 
					//        println("closed client")
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    fun webSocketWaitForConnectTest() = runBlocking {
 | 
				
			||||||
 | 
					        initCrypto()
 | 
				
			||||||
 | 
					//        fun Application.
 | 
				
			||||||
 | 
					        Log.connectConsole(Log.Level.DEBUG)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        val clientInterface = KiloInterface<Unit>().apply {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        val port = Random.nextInt(8080, 9090)
 | 
				
			||||||
 | 
					        var clientConnectCalls = 0
 | 
				
			||||||
 | 
					        // It should repeatedly reconnect, and we will count:
 | 
				
			||||||
 | 
					        KiloClient(clientInterface, SigningSecretKey.new()) {
 | 
				
			||||||
 | 
					            clientConnectCalls++
 | 
				
			||||||
 | 
					            KiloConnectionData(websocketTransportDevice("ws://localhost:$port/kp"), Unit)
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        delay(1200)
 | 
				
			||||||
 | 
					        // and check:
 | 
				
			||||||
 | 
					//        println("connection attemtps: $clientConnectCalls")
 | 
				
			||||||
 | 
					        assertTrue { clientConnectCalls > 1 }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user