tcp optimization and graceful close

This commit is contained in:
Sergey Chernov 2023-11-15 11:47:56 +03:00
parent f92431a281
commit f02b390ed4
5 changed files with 94 additions and 21 deletions

View File

@ -2,10 +2,14 @@ package net.sergeych.kiloparsec.adapter
import kotlinx.coroutines.channels.Channel
/**
* Transport device for inet protocol family with graceful shutdown on [close]
*/
@Suppress("unused")
class InetTransportDevice(
inputChannel: Channel<UByteArray?>,
outputChannel: Channel<UByteArray>,
val remoteAddress: NetworkAddress,
onclose: ()->Unit = {}
) : ProxyDevice(inputChannel, outputChannel, onclose)
val flush: suspend ()->Unit = {},
doClose: suspend ()->Unit = {}
) : ProxyDevice(inputChannel, outputChannel, doClose)

View File

@ -8,7 +8,7 @@ import net.sergeych.kiloparsec.Transport
open class ProxyDevice(
inputChannel: Channel<UByteArray?>,
outputChannel: Channel<UByteArray>,
private val onClose: ()->Unit = {}): Transport.Device {
private val onClose: suspend ()->Unit = {}): Transport.Device {
override val input: ReceiveChannel<UByteArray?> = inputChannel
override val output: SendChannel<UByteArray> = outputChannel

View File

@ -0,0 +1,20 @@
package net.sergeych.tools
import kotlinx.coroutines.cancel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.launch
/**
* suspend until the flow produces the value to which the
* predicate returns true
*/
suspend fun <T>Flow<T>.waitFor(predicate: (T)->Boolean) {
coroutineScope {
launch {
collect {
if( predicate(it) ) cancel()
}
}
}
}

View File

@ -3,13 +3,14 @@ package net.sergeych.kiloparsec.adapter
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import net.sergeych.bintools.toDump
import kotlinx.coroutines.flow.MutableStateFlow
import net.sergeych.crypto.encodeVarUnsigned
import net.sergeych.crypto.readVarUnsigned
import net.sergeych.kiloparsec.Transport
import net.sergeych.mp_logger.LogTag
import net.sergeych.mp_logger.warning
import net.sergeych.mp_tools.globalLaunch
import net.sergeych.tools.waitFor
import java.net.InetSocketAddress
import java.net.StandardSocketOptions.TCP_NODELAY
import java.nio.ByteBuffer
@ -22,8 +23,7 @@ private val log = LogTag("ASTD")
/**
* Prepend block with its size, varint-encoded
*/
private fun encode(block: UByteArray): ByteArray
= (encodeVarUnsigned(block.size.toUInt()) + block).toByteArray()
private fun encode(block: UByteArray): ByteArray = (encodeVarUnsigned(block.size.toUInt()) + block).toByteArray()
/**
* Convert asynchronous socket to a [Transport.Device] using non-blocking nio,
@ -35,14 +35,19 @@ suspend fun asyncSocketToDevice(socket: AsynchronousSocketChannel): InetTranspor
val deferredDevice = CompletableDeferred<InetTransportDevice>()
globalLaunch {
coroutineScope {
val sendQueueEmpty = MutableStateFlow(true)
val receiving = MutableStateFlow(false)
fun stop() {
cancel()
}
// We're in block mode, every block we send worth immediate sending, we do not
// send partial blocks, so:
socket.setOption(TCP_NODELAY, true)
// socket input is to be parsed for blocks, so we receive bytes
// and decode them to blocks
val input = Channel<UByte>(1024)
// copy from socket to input:
// copy incoming data from the socket to input channel:
launch {
val data = ByteArray(1024)
val inb = ByteBuffer.wrap(data)
@ -53,32 +58,41 @@ suspend fun asyncSocketToDevice(socket: AsynchronousSocketChannel): InetTranspor
}
if (size < 0) stop()
else {
println("recvd:\n${data.sliceArray(0..<size).toDump()}\n------------------")
// println("recvd:\n${data.sliceArray(0..<size).toDump()}\n------------------")
for (i in 0..<size) input.send(data[i].toUByte())
}
}
}
// output is blocks, so we sent blocks:
// output is blocks, so we sent transformed, framed blocks:
val outputBlocks = Channel<UByteArray>()
// copy from output to socket:
launch {
try {
while (isActive) {
// wait for the first block to send
sendQueueEmpty.value = outputBlocks.isEmpty
var data = encode(outputBlocks.receive())
// now we're sending, so queue state is sending:
sendQueueEmpty.value = false
// if there are more, take them all (NO_DELAY optimization)
while (!outputBlocks.isEmpty)
data += encode(outputBlocks.receive())
// now send the aggregate:
// now send it all together:
val outBuff = ByteBuffer.wrap(data)
val cnt = suspendCoroutine { continuation ->
socket.write(outBuff, continuation, IntCompletionHandler)
}
// be sure it was all sent
if( outBuff.position() != data.size || cnt != data.size) {
throw RuntimeException("PArtial write!")
if (outBuff.position() != data.size || cnt != data.size) {
throw RuntimeException("unexpected partial write")
}
}
// in the case of just breaking out of the loop:
sendQueueEmpty.value = true
} catch (_: ClosedReceiveChannelException) {
stop()
}
@ -89,7 +103,9 @@ suspend fun asyncSocketToDevice(socket: AsynchronousSocketChannel): InetTranspor
launch {
try {
while (isActive) {
receiving.value = !input.isEmpty
val size = readVarUnsigned(input)
receiving.value = true
if (size == 0u) log.warning { "zero size block is ignored!" }
else {
val block = UByteArray(size.toInt())
@ -105,11 +121,33 @@ suspend fun asyncSocketToDevice(socket: AsynchronousSocketChannel): InetTranspor
inputBlocks.send(null)
stop()
}
receiving.value = false
}
// SocketAddress.
// wait until send queue is empty
suspend fun flush() {
yield()
// do not slow down with collect if it is ok by now:
if (!sendQueueEmpty.value || !outputBlocks.isEmpty)
// wait until all output is sent
sendQueueEmpty.waitFor { it && outputBlocks.isEmpty }
}
val addr = socket.remoteAddress as InetSocketAddress
deferredDevice.complete(
InetTransportDevice(inputBlocks, outputBlocks, JvmNetworkAddress(addr.address,addr.port)) { stop() }
InetTransportDevice(inputBlocks, outputBlocks, JvmNetworkAddress(addr.address, addr.port),
{ flush() }
) {
yield()
// wait until all received data are parsed, but not too long
withTimeoutOrNull( 1000 ) {
receiving.waitFor { !it }
}
// graceful close: flush output
flush()
// then stop it
stop()
}
)
}
globalLaunch { socket.close() }

View File

@ -9,8 +9,10 @@ import net.sergeych.kiloparsec.adapter.acceptTcpDevice
import net.sergeych.kiloparsec.adapter.connectTcpDevice
import net.sergeych.kiloparsec.adapter.toNetworkAddress
import net.sergeych.mp_logger.Log
import net.sergeych.tools.ProtectedOp
import org.junit.jupiter.api.Assertions.assertEquals
import kotlin.test.Test
import kotlin.test.assertContains
class NetworkTest {
@ -35,6 +37,8 @@ class NetworkTest {
coroutineScope {
val serverFlow = acceptTcpDevice(17171)
val op = ProtectedOp()
var pills = setOf<String>()
val j = launch {
serverFlow.collect { device ->
launch {
@ -42,13 +46,14 @@ class NetworkTest {
device.output.send("Great".encodeToUByteArray())
while (true) {
val x = device.input.receive()?.decodeFromUByteArray() ?: break
if (x == "Goodbye") break
if (x == "die") {
println("collector get poisoned pill")
if (x.startsWith("die")) {
op {
pills += x
}
cancel()
break
}
println("ignoring unexpected input: $x")
else
println("ignoring unexpected input: $x")
}
}
}
@ -59,14 +64,20 @@ class NetworkTest {
assertEquals("Hello, world!", s.input.receive()!!.decodeFromUByteArray())
assertEquals("Great", s.input.receive()!!.decodeFromUByteArray())
s.output.send("Goodbye".encodeToUByteArray())
s.output.send("die1".encodeToUByteArray())
s.close()
}
val s1 = connectTcpDevice("127.0.1.1:17171".toNetworkAddress())
assertEquals("Hello, world!", s1.input.receive()!!.decodeFromUByteArray())
assertEquals("Great", s1.input.receive()!!.decodeFromUByteArray())
s1.output.send("die".encodeToUByteArray())
delay(200)
s1.output.send("die2".encodeToUByteArray())
s1.close()
// check that channels were flushed prior to closed:
assertContains(pills, "die1")
assertContains(pills, "die2")
// Check that server jobs are closed
j.cancelAndJoin()
}
}