work in porgress: fixing close connection

This commit is contained in:
Sergey Chernov 2024-03-23 01:29:48 +01:00
parent c0ca802a30
commit 93dc66acc5
13 changed files with 146 additions and 39 deletions

View File

@ -31,7 +31,8 @@ class Command<A, R>(
suspend fun exec(packedArgs: UByteArray, handler: suspend (A) -> R): UByteArray = suspend fun exec(packedArgs: UByteArray, handler: suspend (A) -> R): UByteArray =
BipackEncoder.encode( BipackEncoder.encode(
resultSerializer, resultSerializer,
handler(BipackDecoder.decode(packedArgs.toDataSource(), argsSerializer)) handler(
BipackDecoder.decode(packedArgs.toDataSource(), argsSerializer))
).toUByteArray() ).toUByteArray()
companion object { companion object {

View File

@ -56,6 +56,7 @@ class KiloClient<S>(
delay(1000) delay(1000)
} catch (_: CancellationException) { } catch (_: CancellationException) {
debug { "cancelled" } debug { "cancelled" }
break
} catch (t: Throwable) { } catch (t: Throwable) {
exception { "unexpected exception" to t } exception { "unexpected exception" to t }
delay(1000) delay(1000)

View File

@ -1,8 +1,15 @@
package net.sergeych.kiloparsec package net.sergeych.kiloparsec
import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.CompletableDeferred
import net.sergeych.mp_logger.LogTag
import net.sergeych.mp_logger.Loggable
import net.sergeych.mp_logger.debug
import net.sergeych.mp_logger.info
import net.sergeych.tools.AtomicCounter
import net.sergeych.utools.pack import net.sergeych.utools.pack
private val idCounter = AtomicCounter(0)
/** /**
* This class is not normally used directly. This is a local interface that supports * This class is not normally used directly. This is a local interface that supports
* secure transport command layer (encrypted calls/results) to work with [KiloRemoteInterface]. * secure transport command layer (encrypted calls/results) to work with [KiloRemoteInterface].
@ -12,7 +19,7 @@ import net.sergeych.utools.pack
internal class KiloL0Interface<T>( internal class KiloL0Interface<T>(
private val clientInterface: LocalInterface<KiloScope<T>>, private val clientInterface: LocalInterface<KiloScope<T>>,
private val deferredParams: CompletableDeferred<KiloParams<T>>, private val deferredParams: CompletableDeferred<KiloParams<T>>,
): LocalInterface<Unit>() { ) : LocalInterface<Unit>(), Loggable by LogTag("KL0:${idCounter.incrementAndGet()}") {
init { init {
// local interface uses the same session as a client: // local interface uses the same session as a client:
addErrorProvider(clientInterface) addErrorProvider(clientInterface)
@ -27,7 +34,10 @@ internal class KiloL0Interface<T>(
0u, 0u,
clientInterface.execute(params.scope, call.name, call.serializedArgs) clientInterface.execute(params.scope, call.name, call.serializedArgs)
) )
} catch (t: Throwable) { } catch(t: RemoteInterface.ClosedException) {
throw t
}
catch (t: Throwable) {
clientInterface.encodeError(0u, t) clientInterface.encodeError(0u, t)
} }
params.encrypt(pack(result)) params.encrypt(pack(result))

View File

@ -36,7 +36,7 @@ data class KiloParams<S>(
val sessionKeyPair: KeyExchangeSessionKeyPair, val sessionKeyPair: KeyExchangeSessionKeyPair,
val scopeSession: S, val scopeSession: S,
val remoteIdentity: SigningKey.Public?, val remoteIdentity: SigningKey.Public?,
val remoteTransport: RemoteInterface val remoteTransport: RemoteInterface,
) { ) {
@Serializable @Serializable
data class Package( data class Package(

View File

@ -32,6 +32,9 @@ class KiloServer<S>(
} }
catch(_: CancellationException) { catch(_: CancellationException) {
} }
catch(_: RemoteInterface.ClosedException) {
info { "Closed exception caught, closing" }
}
catch (t: Throwable) { catch (t: Throwable) {
exception { "unexpected while creating kiloclient" to t } exception { "unexpected while creating kiloclient" to t }
} }

View File

@ -59,7 +59,6 @@ class KiloServerConnection<S>(
null, null,
this@KiloServerConnection this@KiloServerConnection
) )
Handshake(1u, pair.publicKey, serverSigningKey?.seal(params!!.token)) Handshake(1u, pair.publicKey, serverSigningKey?.seal(params!!.token))
} }

View File

@ -1,11 +1,17 @@
package net.sergeych.kiloparsec package net.sergeych.kiloparsec
import net.sergeych.mp_logger.LogTag
import net.sergeych.mp_logger.Loggable
import net.sergeych.mp_logger.info
import net.sergeych.tools.AtomicCounter
import net.sergeych.utools.firstNonNull import net.sergeych.utools.firstNonNull
import kotlin.reflect.KClass import kotlin.reflect.KClass
private typealias RawCommandHandler<C> = suspend (C, UByteArray) -> UByteArray private typealias RawCommandHandler<C> = suspend (C, UByteArray) -> UByteArray
open class LocalInterface<S> { private val idCounter = AtomicCounter()
open class LocalInterface<S>: Loggable by LogTag("LocalInterface${idCounter.incrementAndGet()}") {
private val commands = mutableMapOf<String, RawCommandHandler<S>>() private val commands = mutableMapOf<String, RawCommandHandler<S>>()
@ -72,14 +78,16 @@ open class LocalInterface<S> {
fun encodeError(forId: UInt, t: Throwable): Transport.Block.Error = fun encodeError(forId: UInt, t: Throwable): Transport.Block.Error =
getErrorCode(t)?.let { Transport.Block.Error(forId, it, t.message) } getErrorCode(t)?.let { Transport.Block.Error(forId, it, t.message) }
?: Transport.Block.Error(forId, "UnknownError", t.message) ?: Transport.Block.Error(forId, "UnknownError", "${t::class.simpleName}: ${t.message}")
open fun getErrorBuilder(code: String): ((String, UByteArray?) -> Throwable)? = open fun getErrorBuilder(code: String): ((String, UByteArray?) -> Throwable)? =
errorBuilder[code] ?: errorProviders.firstNonNull { it.getErrorBuilder(code) } errorBuilder[code] ?: errorProviders.firstNonNull { it.getErrorBuilder(code) }
fun decodeError(tbe: Transport.Block.Error): Throwable = fun decodeError(tbe: Transport.Block.Error): Throwable =
getErrorBuilder(tbe.code)?.invoke(tbe.message, tbe.extra) getErrorBuilder(tbe.code)?.invoke(tbe.message, tbe.extra)
?: RemoteInterface.RemoteException(tbe) ?: RemoteInterface.RemoteException(tbe).also {
info { "can't decode error ${tbe.code}: ${tbe.message}" }
}
fun decodeAndThrow(tbe: Transport.Block.Error): Nothing { fun decodeAndThrow(tbe: Transport.Block.Error): Nothing {
throw decodeError(tbe) throw decodeError(tbe)

View File

@ -1,6 +1,8 @@
package net.sergeych.kiloparsec package net.sergeych.kiloparsec
import kotlinx.coroutines.* import kotlinx.coroutines.*
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.channels.ClosedSendChannelException
import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.Mutex
@ -147,6 +149,12 @@ class Transport<S>(
debug { "awaiting incoming blocks" } debug { "awaiting incoming blocks" }
while (isActive && !isClosed) { while (isActive && !isClosed) {
try { try {
debug { "input step starting closed=$isClosed active=$isActive"}
if( isClosed ) {
info { "breaking transort loop on closed"}
break
}
device.input.receive().let { packed -> device.input.receive().let { packed ->
debug { "<<<\n${packed.toDump()}" } debug { "<<<\n${packed.toDump()}" }
val b = unpack<Block>(packed) val b = unpack<Block>(packed)
@ -178,9 +186,11 @@ class Transport<S>(
) )
) )
} catch (x: RemoteInterface.ClosedException) { } catch (x: RemoteInterface.ClosedException) {
// strange case: handler throws closed? // handler forced close
error { "not supported: command handler for $b has thrown ClosedException" } warning { "handler requested closing of the connection"}
send(Block.Error(b.id, "UnexpectedException", x.message)) isClosed = true
runCatching { device.close() }
throw x
} catch (x: RemoteInterface.RemoteException) { } catch (x: RemoteInterface.RemoteException) {
send(Block.Error(b.id, x.code, x.text, x.extra)) send(Block.Error(b.id, x.code, x.text, x.extra))
} catch (t: Throwable) { } catch (t: Throwable) {
@ -189,19 +199,34 @@ class Transport<S>(
.also { debug { "command executed: ${b.name}" } } .also { debug { "command executed: ${b.name}" } }
} }
} }
debug { "=---------------------------------------------"}
} }
} catch (_: CancellationException) { debug { "input step performed closed=$isClosed active=$isActive"}
} catch (_: ClosedSendChannelException) {
info { "closed send channel" }
isClosed = true
} catch (_: ClosedReceiveChannelException) {
info { "closed receive channel"}
isClosed = true
}
catch (_: CancellationException) {
info { "loop is cancelled" } info { "loop is cancelled" }
isClosed = true isClosed = true
} catch( _: RemoteInterface.ClosedException) {
debug { "git closed exception here, ignoring" }
isClosed = true
} catch (t: Throwable) { } catch (t: Throwable) {
exception { "channel closed on error" to t } exception { "channel closed on error" to t }
info { "isa? $isActive / $isClosed" } info { "isa? $isActive / $isClosed" }
runCatching { device.close() }
isClosed = true isClosed = true
} }
} }
debug { "leaving transport loop" }
access.withLock { access.withLock {
debug { "access lock obtained"}
isClosed = true isClosed = true
debug { "closgin device $device" }
runCatching { device.close() }
for (c in calls.values) c.completeExceptionally(RemoteInterface.ClosedException()) for (c in calls.values) c.completeExceptionally(RemoteInterface.ClosedException())
calls.clear() calls.clear()
} }
@ -211,7 +236,12 @@ class Transport<S>(
} }
private suspend fun send(block: Block) { private suspend fun send(block: Block) {
device.output.send(pack(block)) try {
device.output.send(pack(block))
}
catch(_: ClosedSendChannelException) {
throw RemoteInterface.ClosedException()
}
} }
} }

View File

@ -3,6 +3,7 @@ package net.sergeych.kiloparsec.adapter
import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.channels.SendChannel
import net.sergeych.kiloparsec.RemoteInterface
import net.sergeych.kiloparsec.Transport import net.sergeych.kiloparsec.Transport
import net.sergeych.tools.AtomicCounter import net.sergeych.tools.AtomicCounter
@ -10,7 +11,7 @@ private val counter = AtomicCounter()
open class ProxyDevice( open class ProxyDevice(
inputChannel: Channel<UByteArray>, inputChannel: Channel<UByteArray>,
outputChannel: Channel<UByteArray>, outputChannel: Channel<UByteArray>,
private val onClose: suspend ()->Unit = {}): Transport.Device { private val onClose: suspend ()->Unit = { throw RemoteInterface.ClosedException() }): Transport.Device {
override val input: ReceiveChannel<UByteArray> = inputChannel override val input: ReceiveChannel<UByteArray> = inputChannel
override val output: SendChannel<UByteArray> = outputChannel override val output: SendChannel<UByteArray> = outputChannel

View File

@ -7,8 +7,12 @@ import io.ktor.websocket.*
import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.cancel import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.channels.ClosedSendChannelException
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import net.sergeych.crypto2.SigningKey import net.sergeych.crypto2.SigningKey
import net.sergeych.crypto2.toDump
import net.sergeych.kiloparsec.KiloClient import net.sergeych.kiloparsec.KiloClient
import net.sergeych.kiloparsec.KiloConnectionData import net.sergeych.kiloparsec.KiloConnectionData
import net.sergeych.kiloparsec.KiloInterface import net.sergeych.kiloparsec.KiloInterface
@ -21,7 +25,7 @@ import net.sergeych.tools.AtomicCounter
private val counter = AtomicCounter() private val counter = AtomicCounter()
fun <S>websocketClient( fun <S> websocketClient(
path: String, path: String,
clientInterface: KiloInterface<S> = KiloInterface(), clientInterface: KiloInterface<S> = KiloInterface(),
client: HttpClient = HttpClient { install(WebSockets) }, client: HttpClient = HttpClient { install(WebSockets) },
@ -48,36 +52,42 @@ fun <S>websocketClient(
url.port = u.port url.port = u.port
url.encodedPath = u.encodedPath url.encodedPath = u.encodedPath
url.parameters.appendAll(u.parameters) url.parameters.appendAll(u.parameters)
log.info { "kiloparsec server URL: $url" } log.info { "kiloparsec server URL: $url" }
}) { }) {
try { try {
log.info { "connected to the server" } log.info { "connected to the server" }
println("SENDING!!!") println("SENDING!!!")
send("Helluva") send("Helluva")
launch { launch {
for (block in output) { try {
send(block.toByteArray()) for (block in output) {
send(block.toByteArray())
}
log.info { "input is closed, closing the websocket" }
} catch (_: ClosedSendChannelException) {
log.info { "send channel closed" }
} }
log.info { "input is closed, closing the websocket" }
cancel() cancel()
} }
for (f in incoming) { for (f in incoming) {
if (f is Frame.Binary) { if (f is Frame.Binary) {
input.send(f.readBytes().toUByteArray()) input.send(f.readBytes().toUByteArray().also {
println("incoming\n${it.toDump()}")
})
} else { } else {
log.warning { "ignoring unexpected frame of type ${f.frameType}" } log.warning { "ignoring unexpected frame of type ${f.frameType}" }
} }
} }
} } catch (_: CancellationException) {
catch(_:CancellationException) { } catch( _: ClosedReceiveChannelException) {
} log.warning { "receive channel closed unexpectedly" }
catch(t: Throwable) { } catch (t: Throwable) {
log.exception { "unexpected error" to t } log.exception { "unexpected error" to t }
} }
log.info { "closing connection" } log.info { "closing connection" }
} }
} }
val device = ProxyDevice(input,output) { val device = ProxyDevice(input, output) {
input.close() input.close()
// we need to explicitly close the coroutine job, or it can hang for a long time // we need to explicitly close the coroutine job, or it can hang for a long time
// leaking resources. // leaking resources.

View File

@ -8,6 +8,7 @@ inline fun <reified T: Throwable>assertThrows(f: ()->Unit): T {
} }
catch(x: Throwable) { catch(x: Throwable) {
if( x is T ) return x if( x is T ) return x
fail("expected to throw $name but instead threw ${x::class.simpleName}: $x") println("expected to throw $name but instead threw ${x::class.simpleName}: $x\b\n${x.stackTraceToString()}")
fail("expected to throw $name but instead threw ${x::class.simpleName}: $x\b\n${x.stackTraceToString()}")
} }
} }

View File

@ -6,15 +6,16 @@ import io.ktor.server.websocket.*
import io.ktor.websocket.* import io.ktor.websocket.*
import kotlinx.coroutines.cancel import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.channels.ClosedSendChannelException
import kotlinx.coroutines.isActive import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import net.sergeych.crypto2.SigningKey import net.sergeych.crypto2.SigningKey
import net.sergeych.crypto2.toDump import net.sergeych.crypto2.toDump
import net.sergeych.kiloparsec.KiloInterface import net.sergeych.kiloparsec.KiloInterface
import net.sergeych.kiloparsec.KiloServerConnection import net.sergeych.kiloparsec.KiloServerConnection
import net.sergeych.mp_logger.LogTag import net.sergeych.kiloparsec.RemoteInterface
import net.sergeych.mp_logger.debug import net.sergeych.mp_logger.*
import net.sergeych.mp_logger.warning
import net.sergeych.tools.AtomicCounter import net.sergeych.tools.AtomicCounter
import java.time.Duration import java.time.Duration
@ -49,23 +50,36 @@ fun <S> Application.setupWebsocketServer(
} }
val server = KiloServerConnection( val server = KiloServerConnection(
localInterface, localInterface,
ProxyDevice(input, output) { input.close() }, ProxyDevice(input, output),
createSession(), createSession(),
serverKey serverKey
) )
launch { server.run() } launch { server.run() }
log.debug { "KSC started, looking for incoming frames" } log.debug { "KSC started, looking for incoming frames" }
for( f in incoming) { for (f in incoming) {
log.debug { "incoming frame: ${f.frameType}" } log.debug { "incoming frame: ${f.frameType}" }
if (f is Frame.Binary) if (f is Frame.Binary)
input.send(f.readBytes().toUByteArray().also { try {
log.debug { "in frame\n${it.toDump()}" } input.send(f.readBytes().toUByteArray().also {
}) log.debug { "in frame\n${it.toDump()}" }
})
} catch (_: RemoteInterface.ClosedException) {
log.info { "caught local closed exception, closing" }
break
} catch (_: ClosedReceiveChannelException) {
log.info { "receive channel is closed, closing connection" }
break
} catch (t: Throwable) {
log.exception { "unexpected exception, server connection will close" to t }
break
}
else else
log.warning { "unknown frame type ${f.frameType}, ignoring" } log.warning { "unknown frame type ${f.frameType}, ignoring" }
} }
log.debug { "closing the server" } log.debug { "closing the server" }
println("****************prec")
cancel() cancel()
println("****************postc")
} }
} }
} }

View File

@ -1,7 +1,11 @@
package net.sergeych.kiloparsec package net.sergeych.kiloparsec
import assertThrows
import io.ktor.server.engine.* import io.ktor.server.engine.*
import io.ktor.server.netty.* import io.ktor.server.netty.*
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.take
import kotlinx.coroutines.launch
import kotlinx.coroutines.test.runTest import kotlinx.coroutines.test.runTest
import net.sergeych.crypto2.initCrypto import net.sergeych.crypto2.initCrypto
import net.sergeych.kiloparsec.adapter.acceptTcpDevice import net.sergeych.kiloparsec.adapter.acceptTcpDevice
@ -10,9 +14,7 @@ import net.sergeych.kiloparsec.adapter.setupWebsocketServer
import net.sergeych.kiloparsec.adapter.websocketClient import net.sergeych.kiloparsec.adapter.websocketClient
import net.sergeych.mp_logger.Log import net.sergeych.mp_logger.Log
import java.net.InetAddress import java.net.InetAddress
import kotlin.test.Test import kotlin.test.*
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class ClientTest { class ClientTest {
@ -57,6 +59,7 @@ class ClientTest {
fun webSocketTest() = runTest { fun webSocketTest() = runTest {
initCrypto() initCrypto()
// fun Application. // fun Application.
val cmdClose by command<Unit,Unit>()
val cmdGetFoo by command<Unit,String>() val cmdGetFoo by command<Unit,String>()
val cmdSetFoo by command<String,Unit>() val cmdSetFoo by command<String,Unit>()
val cmdCheckConnected by command<Unit,Boolean>() val cmdCheckConnected by command<Unit,Boolean>()
@ -64,12 +67,23 @@ class ClientTest {
Log.connectConsole(Log.Level.DEBUG) Log.connectConsole(Log.Level.DEBUG)
data class Session(var foo: String="not set") data class Session(var foo: String="not set")
var closeCounter = 0
val serverInterface = KiloInterface<Session>().apply { val serverInterface = KiloInterface<Session>().apply {
var connectedCalled = false var connectedCalled = false
onConnected { connectedCalled = true } onConnected { connectedCalled = true }
on(cmdGetFoo) { session.foo } on(cmdGetFoo) { session.foo }
on(cmdSetFoo) { session.foo = it } on(cmdSetFoo) { session.foo = it }
on(cmdCheckConnected) { connectedCalled } on(cmdCheckConnected) { connectedCalled }
on(cmdClose) {
throw RemoteInterface.ClosedException()
// if( closeCounter < 2 ) {
// println("-------------------------- call close!")
// throw RemoteInterface.ClosedException()
// }
// else {
// println("close counter $closeCounter, ignoring")
// }
}
} }
// val server = setupWebsoketServer() // val server = setupWebsoketServer()
val ns: NettyApplicationEngine = embeddedServer(Netty, port = 8080, host = "0.0.0.0", module = { val ns: NettyApplicationEngine = embeddedServer(Netty, port = 8080, host = "0.0.0.0", module = {
@ -77,6 +91,14 @@ class ClientTest {
}).start(wait = false) }).start(wait = false)
val client = websocketClient<Unit>("ws://localhost:8080/kp") val client = websocketClient<Unit>("ws://localhost:8080/kp")
val states = mutableListOf<Boolean>()
val collector = launch {
client.state.collect {
println("got: $closeCounter/$it")
states += it
if( !it) { closeCounter++ }
}
}
println(1) println(1)
assertEquals(true, client.call(cmdCheckConnected)) assertEquals(true, client.call(cmdCheckConnected))
assertTrue { client.state.value } assertTrue { client.state.value }
@ -87,9 +109,16 @@ class ClientTest {
println(4) println(4)
assertEquals("foo", client.call(cmdGetFoo)) assertEquals("foo", client.call(cmdGetFoo))
println(5) println(5)
assertThrows<RemoteInterface.ClosedException> {
client.call(cmdClose)
}
println("0------------------------------------------------------------------------------connection should be closed")
// assertFalse { client.state.value }
// assertEquals("foo", client.call(cmdGetFoo))
client.close() client.close()
ns.stop() ns.stop()
collector.cancel()
println("----= states: $states")
println("stopped server") println("stopped server")
println("closed client") println("closed client")
} }