fixed automatic reconnection in TCP client/server

This commit is contained in:
Sergey Chernov 2024-06-17 17:42:16 +07:00
parent 0d3a8ae95c
commit 825c0bd5f7
7 changed files with 95 additions and 49 deletions

View File

@ -6,6 +6,8 @@ import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.isActive import kotlinx.coroutines.isActive
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import net.sergeych.crypto2.SigningKey import net.sergeych.crypto2.SigningKey
import net.sergeych.mp_logger.LogTag import net.sergeych.mp_logger.LogTag
import net.sergeych.mp_logger.Loggable import net.sergeych.mp_logger.Loggable
@ -45,11 +47,12 @@ class KiloClient<S>(
debug { "getting connection" } debug { "getting connection" }
val kc = connectionDataFactory() val kc = connectionDataFactory()
debug { "get device and session" } debug { "get device and session" }
val client = KiloClientConnection(localInterface, kc,secretKey) val client = KiloClientConnection(localInterface, kc, secretKey)
deferredClient.complete(client) deferredClient.complete(client)
client.run { client.run {
_state.value = it _state.value = it
} }
resetDeferredClient()
debug { "client run finished" } debug { "client run finished" }
} catch (_: RemoteInterface.ClosedException) { } catch (_: RemoteInterface.ClosedException) {
debug { "remote closed" } debug { "remote closed" }
@ -62,9 +65,8 @@ class KiloClient<S>(
delay(1000) delay(1000)
} }
_state.value = false _state.value = false
if (deferredClient.isActive) resetDeferredClient()
deferredClient = CompletableDeferred() delay(100)
delay(1000)
} }
} }
@ -73,7 +75,23 @@ class KiloClient<S>(
debug { "client is closed" } debug { "client is closed" }
} }
override suspend fun <A, R> call(cmd: Command<A, R>, args: A): R = deferredClient.await().call(cmd, args) private val defMutex = Mutex()
private suspend fun resetDeferredClient() {
defMutex.withLock {
if (!deferredClient.isActive) {
deferredClient = CompletableDeferred()
}
}
}
override suspend fun <A, R> call(cmd: Command<A, R>, args: A): R =
try {
deferredClient.await().call(cmd, args)
} catch (t: RemoteInterface.ClosedException) {
resetDeferredClient()
throw t
}
/** /**
* Current session token. This is a per-connection unique random value same on the client and server part so * Current session token. This is a per-connection unique random value same on the client and server part so
@ -142,11 +160,11 @@ class KiloClient<S>(
internal fun build(): KiloClient<S> { internal fun build(): KiloClient<S> {
val i = KiloInterface<S>() val i = KiloInterface<S>()
for(ep in errorProviders) i.addErrorProvider(ep) for (ep in errorProviders) i.addErrorProvider(ep)
interfaceBuilder?.let { i.it() } interfaceBuilder?.let { i.it() }
val connector = connectionBuilder ?: throw IllegalArgumentException("connect handler was not set") val connector = connectionBuilder ?: throw IllegalArgumentException("connect handler was not set")
return KiloClient(i,secretIdKey) { return KiloClient(i, secretIdKey) {
KiloConnectionData(connector(),sessionBuilder()) KiloConnectionData(connector(), sessionBuilder())
} }
} }
} }

View File

@ -56,6 +56,8 @@ class Transport<S>(
* possible. This method must not throw exceptions. * possible. This method must not throw exceptions.
*/ */
suspend fun close() suspend fun close()
suspend fun flush() {}
} }
@Serializable(TransportBlockSerializer::class) @Serializable(TransportBlockSerializer::class)
@ -184,6 +186,7 @@ class Transport<S>(
// handler forced close // handler forced close
warning { "handler requested closing of the connection (${x.flushSendQueue}"} warning { "handler requested closing of the connection (${x.flushSendQueue}"}
isClosed = true isClosed = true
if( x.flushSendQueue ) device.flush()
device.close() device.close()
} catch (x: RemoteInterface.RemoteException) { } catch (x: RemoteInterface.RemoteException) {
send(Block.Error(b.id, x.code, x.text, x.extra)) send(Block.Error(b.id, x.code, x.text, x.extra))
@ -207,6 +210,11 @@ class Transport<S>(
info { "closing connection by local request ($cce)"} info { "closing connection by local request ($cce)"}
device.close() device.close()
} }
catch(t: RemoteInterface.ClosedException) {
// it is ok: we just exit the coroutine normally
// and mark we're closing
isClosed = true
}
catch (_: CancellationException) { catch (_: CancellationException) {
info { "loop is cancelled with CancellationException" } info { "loop is cancelled with CancellationException" }
isClosed = true isClosed = true

View File

@ -10,8 +10,8 @@ class InetTransportDevice(
inputChannel: Channel<UByteArray>, inputChannel: Channel<UByteArray>,
outputChannel: Channel<UByteArray>, outputChannel: Channel<UByteArray>,
val remoteAddress: NetworkAddress, val remoteAddress: NetworkAddress,
val flush: suspend ()->Unit = {}, doClose: (suspend ()->Unit)? = null,
doClose: suspend ()->Unit = {} doFlush: (suspend ()->Unit)? = null,
) : ProxyDevice(inputChannel, outputChannel, doClose) { ) : ProxyDevice(inputChannel, outputChannel, doClose, doFlush) {
override fun toString(): String = "@$remoteAddress" override fun toString(): String = "@$remoteAddress"
} }

View File

@ -3,20 +3,34 @@ package net.sergeych.kiloparsec.adapter
import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.channels.SendChannel
import net.sergeych.kiloparsec.RemoteInterface import kotlinx.coroutines.delay
import net.sergeych.kiloparsec.Transport import net.sergeych.kiloparsec.Transport
import net.sergeych.tools.AtomicCounter import net.sergeych.tools.AtomicCounter
private val counter = AtomicCounter() private val counter = AtomicCounter()
open class ProxyDevice( open class ProxyDevice(
inputChannel: Channel<UByteArray>, private val inputChannel: Channel<UByteArray>,
outputChannel: Channel<UByteArray>, private val outputChannel: Channel<UByteArray>,
private val onClose: suspend ()->Unit = { throw RemoteInterface.ClosedException() }): Transport.Device { private val doClose: (suspend ()->Unit)? = null,
private val doFlush: (suspend ()->Unit)? = null,
): Transport.Device {
override val input: ReceiveChannel<UByteArray> = inputChannel override val input: ReceiveChannel<UByteArray> = inputChannel
override val output: SendChannel<UByteArray> = outputChannel override val output: SendChannel<UByteArray> = outputChannel
override suspend fun close() { override suspend fun close() {
onClose() doClose?.invoke()
runCatching { inputChannel.close() }
runCatching { outputChannel.close() }
}
override suspend fun flush() {
doFlush?.invoke()
var cnt = 10
while(!outputChannel.isEmpty) {
if (cnt-- < 0) break
delay(50)
}
super.flush()
} }
private val id = counter.incrementAndGet() private val id = counter.incrementAndGet()

View File

@ -20,6 +20,7 @@ fun createTestDevice(): Pair<Transport.Device, Transport.Device> {
val d1 = object : Transport.Device { val d1 = object : Transport.Device {
override val input: ReceiveChannel<UByteArray> = p1 override val input: ReceiveChannel<UByteArray> = p1
override val output: SendChannel<UByteArray> = p2 override val output: SendChannel<UByteArray> = p2
override suspend fun close() { override suspend fun close() {
p2.close() p2.close()
} }

View File

@ -6,8 +6,10 @@ import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import net.sergeych.crypto2.encodeVarUnsigned import net.sergeych.crypto2.encodeVarUnsigned
import net.sergeych.crypto2.readVarUnsigned import net.sergeych.crypto2.readVarUnsigned
import net.sergeych.kiloparsec.RemoteInterface
import net.sergeych.kiloparsec.Transport import net.sergeych.kiloparsec.Transport
import net.sergeych.mp_logger.LogTag import net.sergeych.mp_logger.LogTag
import net.sergeych.mp_logger.info
import net.sergeych.mp_logger.warning import net.sergeych.mp_logger.warning
import net.sergeych.mp_tools.globalLaunch import net.sergeych.mp_tools.globalLaunch
import net.sergeych.tools.waitFor import net.sergeych.tools.waitFor
@ -37,9 +39,6 @@ suspend fun asyncSocketToDevice(socket: AsynchronousSocketChannel): InetTranspor
coroutineScope { coroutineScope {
val sendQueueEmpty = MutableStateFlow(true) val sendQueueEmpty = MutableStateFlow(true)
val receiving = MutableStateFlow(false) val receiving = MutableStateFlow(false)
fun stop() {
cancel()
}
// We're in block mode, every block we send worth immediate sending, we do not // We're in block mode, every block we send worth immediate sending, we do not
// send partial blocks, so: // send partial blocks, so:
socket.setOption(TCP_NODELAY, true) socket.setOption(TCP_NODELAY, true)
@ -47,25 +46,37 @@ suspend fun asyncSocketToDevice(socket: AsynchronousSocketChannel): InetTranspor
// socket input is to be parsed for blocks, so we receive bytes // socket input is to be parsed for blocks, so we receive bytes
// and decode them to blocks // and decode them to blocks
val input = Channel<UByte>(1024) val input = Channel<UByte>(1024)
val inputBlocks = Channel<UByteArray>()
// output is blocks, so we sent transformed, framed blocks:
val outputBlocks = Channel<UByteArray>()
fun stop() {
kotlin.runCatching { inputBlocks.close(RemoteInterface.ClosedException()) }
kotlin.runCatching { outputBlocks.close() }
socket.close()
cancel()
}
// copy incoming data from the socket to input channel: // copy incoming data from the socket to input channel:
launch { launch {
val data = ByteArray(1024) val data = ByteArray(1024)
val inb = ByteBuffer.wrap(data) val inb = ByteBuffer.wrap(data)
while (isActive) { kotlin.runCatching {
inb.position(0) while (isActive) {
val size: Int = suspendCoroutine { continuation -> inb.position(0)
socket.read(inb, continuation, IntCompletionHandler) val size: Int = suspendCoroutine { continuation ->
} socket.read(inb, continuation, IntCompletionHandler)
if (size < 0) stop() }
else { if (size < 0) stop()
else {
// println("recvd:\n${data.sliceArray(0..<size).toDump()}\n------------------") // println("recvd:\n${data.sliceArray(0..<size).toDump()}\n------------------")
for (i in 0..<size) input.send(data[i].toUByte()) for (i in 0..<size) input.send(data[i].toUByte())
}
} }
} }
} }
// output is blocks, so we sent transformed, framed blocks:
val outputBlocks = Channel<UByteArray>()
// copy from output to socket: // copy from output to socket:
launch { launch {
try { try {
@ -98,7 +109,6 @@ suspend fun asyncSocketToDevice(socket: AsynchronousSocketChannel): InetTranspor
} }
} }
// transport device copes with blocks: // transport device copes with blocks:
val inputBlocks = Channel<UByteArray>()
// decode blocks from a byte channel read from the socket: // decode blocks from a byte channel read from the socket:
launch { launch {
try { try {
@ -122,30 +132,21 @@ suspend fun asyncSocketToDevice(socket: AsynchronousSocketChannel): InetTranspor
receiving.value = false receiving.value = false
} }
// wait until send queue is empty
suspend fun flush() {
yield()
// do not slow down with collect if it is ok by now:
if (!sendQueueEmpty.value || !outputBlocks.isEmpty)
// wait until all output is sent
sendQueueEmpty.waitFor { it && outputBlocks.isEmpty }
}
val addr = socket.remoteAddress as InetSocketAddress val addr = socket.remoteAddress as InetSocketAddress
deferredDevice.complete( deferredDevice.complete(
InetTransportDevice(inputBlocks, outputBlocks, JvmNetworkAddress(addr.address, addr.port), InetTransportDevice(inputBlocks, outputBlocks, JvmNetworkAddress(addr.address, addr.port), {
{ flush() } val log = LogTag("S:${addr.address}:${addr.port}")
) { log.info { "ASTD is waitig to close" }
yield() yield()
// wait until all received data are parsed, but not too long // wait until all received data are parsed, but not too long
withTimeoutOrNull( 1000 ) { withTimeoutOrNull(500) {
receiving.waitFor { !it } receiving.waitFor { !it }
} }
// graceful close: flush output
flush()
// then stop it // then stop it
log.info { "ASTd is calling STOP" }
stop() stop()
} log.info { "STopped" }
})
) )
} }
globalLaunch { socket.close() } globalLaunch { socket.close() }

View File

@ -1,5 +1,6 @@
package net.sergeych.kiloparsec package net.sergeych.kiloparsec
import assertThrows
import io.ktor.server.engine.* import io.ktor.server.engine.*
import io.ktor.server.netty.* import io.ktor.server.netty.*
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@ -45,14 +46,13 @@ class ClientTest {
onConnected { session.data = "start" } onConnected { session.data = "start" }
on(cmdSave) { session.data = it } on(cmdSave) { session.data = it }
on(cmdLoad) { on(cmdLoad) {
println("load!")
session.data session.data
} }
on(cmdException) { on(cmdException) {
throw TestException() throw TestException()
} }
on(cmdDrop) { on(cmdDrop) {
throw RemoteInterface.ClosedException() throw LocalInterface.BreakConnectionException()
} }
} }
val server = KiloServer(cli, acceptTcpDevice(17101)) { val server = KiloServer(cli, acceptTcpDevice(17101)) {
@ -69,12 +69,16 @@ class ClientTest {
client.call(cmdSave, "foobar") client.call(cmdSave, "foobar")
assertEquals("foobar", client.call(cmdLoad)) assertEquals("foobar", client.call(cmdLoad))
// client.call(cmdException)
val res = kotlin.runCatching { client.call(cmdException) } val res = kotlin.runCatching { client.call(cmdException) }
println(res.exceptionOrNull()) println(res.exceptionOrNull())
assertIs<TestException>(res.exceptionOrNull()) assertIs<TestException>(res.exceptionOrNull())
assertEquals("foobar", client.call(cmdLoad)) assertEquals("foobar", client.call(cmdLoad))
assertThrows<RemoteInterface.ClosedException> { client.call(cmdDrop) }
// reconnect?
assertEquals("start", client.call(cmdLoad))
server.close() server.close()
} }