From f02b390ed4367d099541c7c3872a69b795371f47 Mon Sep 17 00:00:00 2001 From: sergeych Date: Wed, 15 Nov 2023 11:47:56 +0300 Subject: [PATCH] tcp optimization and graceful close --- .../kiloparsec/adapter/InetTransportDevice.kt | 8 ++- .../kiloparsec/adapter/ProxyDevice.kt | 2 +- .../kotlin/net/sergeych/tools/flow_tools.kt | 20 +++++++ .../kiloparsec/adapter/asyncSocketToDevice.kt | 60 +++++++++++++++---- .../kiloparsec/adapters/NetworkTest.kt | 25 +++++--- 5 files changed, 94 insertions(+), 21 deletions(-) create mode 100644 src/commonMain/kotlin/net/sergeych/tools/flow_tools.kt diff --git a/src/commonMain/kotlin/net/sergeych/kiloparsec/adapter/InetTransportDevice.kt b/src/commonMain/kotlin/net/sergeych/kiloparsec/adapter/InetTransportDevice.kt index 0ee855c..0110fbe 100644 --- a/src/commonMain/kotlin/net/sergeych/kiloparsec/adapter/InetTransportDevice.kt +++ b/src/commonMain/kotlin/net/sergeych/kiloparsec/adapter/InetTransportDevice.kt @@ -2,10 +2,14 @@ package net.sergeych.kiloparsec.adapter import kotlinx.coroutines.channels.Channel +/** + * Transport device for inet protocol family with graceful shutdown on [close] + */ @Suppress("unused") class InetTransportDevice( inputChannel: Channel, outputChannel: Channel, val remoteAddress: NetworkAddress, - onclose: ()->Unit = {} -) : ProxyDevice(inputChannel, outputChannel, onclose) \ No newline at end of file + val flush: suspend ()->Unit = {}, + doClose: suspend ()->Unit = {} +) : ProxyDevice(inputChannel, outputChannel, doClose) \ No newline at end of file diff --git a/src/commonMain/kotlin/net/sergeych/kiloparsec/adapter/ProxyDevice.kt b/src/commonMain/kotlin/net/sergeych/kiloparsec/adapter/ProxyDevice.kt index b4f7fe9..cf7364e 100644 --- a/src/commonMain/kotlin/net/sergeych/kiloparsec/adapter/ProxyDevice.kt +++ b/src/commonMain/kotlin/net/sergeych/kiloparsec/adapter/ProxyDevice.kt @@ -8,7 +8,7 @@ import net.sergeych.kiloparsec.Transport open class ProxyDevice( inputChannel: Channel, outputChannel: Channel, - private val onClose: ()->Unit = {}): Transport.Device { + private val onClose: suspend ()->Unit = {}): Transport.Device { override val input: ReceiveChannel = inputChannel override val output: SendChannel = outputChannel diff --git a/src/commonMain/kotlin/net/sergeych/tools/flow_tools.kt b/src/commonMain/kotlin/net/sergeych/tools/flow_tools.kt new file mode 100644 index 0000000..6a01325 --- /dev/null +++ b/src/commonMain/kotlin/net/sergeych/tools/flow_tools.kt @@ -0,0 +1,20 @@ +package net.sergeych.tools + +import kotlinx.coroutines.cancel +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.launch + +/** + * suspend until the flow produces the value to which the + * predicate returns true + */ +suspend fun Flow.waitFor(predicate: (T)->Boolean) { + coroutineScope { + launch { + collect { + if( predicate(it) ) cancel() + } + } + } +} \ No newline at end of file diff --git a/src/jvmMain/kotlin/net/sergeych/kiloparsec/adapter/asyncSocketToDevice.kt b/src/jvmMain/kotlin/net/sergeych/kiloparsec/adapter/asyncSocketToDevice.kt index 17903e4..5c7d838 100644 --- a/src/jvmMain/kotlin/net/sergeych/kiloparsec/adapter/asyncSocketToDevice.kt +++ b/src/jvmMain/kotlin/net/sergeych/kiloparsec/adapter/asyncSocketToDevice.kt @@ -3,13 +3,14 @@ package net.sergeych.kiloparsec.adapter import kotlinx.coroutines.* import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.ClosedReceiveChannelException -import net.sergeych.bintools.toDump +import kotlinx.coroutines.flow.MutableStateFlow import net.sergeych.crypto.encodeVarUnsigned import net.sergeych.crypto.readVarUnsigned import net.sergeych.kiloparsec.Transport import net.sergeych.mp_logger.LogTag import net.sergeych.mp_logger.warning import net.sergeych.mp_tools.globalLaunch +import net.sergeych.tools.waitFor import java.net.InetSocketAddress import java.net.StandardSocketOptions.TCP_NODELAY import java.nio.ByteBuffer @@ -22,8 +23,7 @@ private val log = LogTag("ASTD") /** * Prepend block with its size, varint-encoded */ -private fun encode(block: UByteArray): ByteArray - = (encodeVarUnsigned(block.size.toUInt()) + block).toByteArray() +private fun encode(block: UByteArray): ByteArray = (encodeVarUnsigned(block.size.toUInt()) + block).toByteArray() /** * Convert asynchronous socket to a [Transport.Device] using non-blocking nio, @@ -35,14 +35,19 @@ suspend fun asyncSocketToDevice(socket: AsynchronousSocketChannel): InetTranspor val deferredDevice = CompletableDeferred() globalLaunch { coroutineScope { + val sendQueueEmpty = MutableStateFlow(true) + val receiving = MutableStateFlow(false) fun stop() { cancel() } + // We're in block mode, every block we send worth immediate sending, we do not + // send partial blocks, so: socket.setOption(TCP_NODELAY, true) + // socket input is to be parsed for blocks, so we receive bytes // and decode them to blocks val input = Channel(1024) - // copy from socket to input: + // copy incoming data from the socket to input channel: launch { val data = ByteArray(1024) val inb = ByteBuffer.wrap(data) @@ -53,32 +58,41 @@ suspend fun asyncSocketToDevice(socket: AsynchronousSocketChannel): InetTranspor } if (size < 0) stop() else { - println("recvd:\n${data.sliceArray(0..() // copy from output to socket: launch { try { while (isActive) { // wait for the first block to send + sendQueueEmpty.value = outputBlocks.isEmpty var data = encode(outputBlocks.receive()) + + // now we're sending, so queue state is sending: + sendQueueEmpty.value = false + // if there are more, take them all (NO_DELAY optimization) while (!outputBlocks.isEmpty) data += encode(outputBlocks.receive()) - // now send the aggregate: + + // now send it all together: val outBuff = ByteBuffer.wrap(data) val cnt = suspendCoroutine { continuation -> socket.write(outBuff, continuation, IntCompletionHandler) } // be sure it was all sent - if( outBuff.position() != data.size || cnt != data.size) { - throw RuntimeException("PArtial write!") + if (outBuff.position() != data.size || cnt != data.size) { + throw RuntimeException("unexpected partial write") } } + // in the case of just breaking out of the loop: + sendQueueEmpty.value = true } catch (_: ClosedReceiveChannelException) { stop() } @@ -89,7 +103,9 @@ suspend fun asyncSocketToDevice(socket: AsynchronousSocketChannel): InetTranspor launch { try { while (isActive) { + receiving.value = !input.isEmpty val size = readVarUnsigned(input) + receiving.value = true if (size == 0u) log.warning { "zero size block is ignored!" } else { val block = UByteArray(size.toInt()) @@ -105,11 +121,33 @@ suspend fun asyncSocketToDevice(socket: AsynchronousSocketChannel): InetTranspor inputBlocks.send(null) stop() } + receiving.value = false } -// SocketAddress. + + // wait until send queue is empty + suspend fun flush() { + yield() + // do not slow down with collect if it is ok by now: + if (!sendQueueEmpty.value || !outputBlocks.isEmpty) + // wait until all output is sent + sendQueueEmpty.waitFor { it && outputBlocks.isEmpty } + } + val addr = socket.remoteAddress as InetSocketAddress deferredDevice.complete( - InetTransportDevice(inputBlocks, outputBlocks, JvmNetworkAddress(addr.address,addr.port)) { stop() } + InetTransportDevice(inputBlocks, outputBlocks, JvmNetworkAddress(addr.address, addr.port), + { flush() } + ) { + yield() + // wait until all received data are parsed, but not too long + withTimeoutOrNull( 1000 ) { + receiving.waitFor { !it } + } + // graceful close: flush output + flush() + // then stop it + stop() + } ) } globalLaunch { socket.close() } diff --git a/src/jvmTest/kotlin/net/sergeych/kiloparsec/adapters/NetworkTest.kt b/src/jvmTest/kotlin/net/sergeych/kiloparsec/adapters/NetworkTest.kt index a2183e4..00e7d75 100644 --- a/src/jvmTest/kotlin/net/sergeych/kiloparsec/adapters/NetworkTest.kt +++ b/src/jvmTest/kotlin/net/sergeych/kiloparsec/adapters/NetworkTest.kt @@ -9,8 +9,10 @@ import net.sergeych.kiloparsec.adapter.acceptTcpDevice import net.sergeych.kiloparsec.adapter.connectTcpDevice import net.sergeych.kiloparsec.adapter.toNetworkAddress import net.sergeych.mp_logger.Log +import net.sergeych.tools.ProtectedOp import org.junit.jupiter.api.Assertions.assertEquals import kotlin.test.Test +import kotlin.test.assertContains class NetworkTest { @@ -35,6 +37,8 @@ class NetworkTest { coroutineScope { val serverFlow = acceptTcpDevice(17171) + val op = ProtectedOp() + var pills = setOf() val j = launch { serverFlow.collect { device -> launch { @@ -42,13 +46,14 @@ class NetworkTest { device.output.send("Great".encodeToUByteArray()) while (true) { val x = device.input.receive()?.decodeFromUByteArray() ?: break - if (x == "Goodbye") break - if (x == "die") { - println("collector get poisoned pill") + if (x.startsWith("die")) { + op { + pills += x + } cancel() - break } - println("ignoring unexpected input: $x") + else + println("ignoring unexpected input: $x") } } } @@ -59,14 +64,20 @@ class NetworkTest { assertEquals("Hello, world!", s.input.receive()!!.decodeFromUByteArray()) assertEquals("Great", s.input.receive()!!.decodeFromUByteArray()) s.output.send("Goodbye".encodeToUByteArray()) + s.output.send("die1".encodeToUByteArray()) s.close() } val s1 = connectTcpDevice("127.0.1.1:17171".toNetworkAddress()) assertEquals("Hello, world!", s1.input.receive()!!.decodeFromUByteArray()) assertEquals("Great", s1.input.receive()!!.decodeFromUByteArray()) - s1.output.send("die".encodeToUByteArray()) - delay(200) + s1.output.send("die2".encodeToUByteArray()) s1.close() + + // check that channels were flushed prior to closed: + assertContains(pills, "die1") + assertContains(pills, "die2") + + // Check that server jobs are closed j.cancelAndJoin() } }