tcp optimization and graceful close
This commit is contained in:
parent
f92431a281
commit
f02b390ed4
@ -2,10 +2,14 @@ package net.sergeych.kiloparsec.adapter
|
||||
|
||||
import kotlinx.coroutines.channels.Channel
|
||||
|
||||
/**
|
||||
* Transport device for inet protocol family with graceful shutdown on [close]
|
||||
*/
|
||||
@Suppress("unused")
|
||||
class InetTransportDevice(
|
||||
inputChannel: Channel<UByteArray?>,
|
||||
outputChannel: Channel<UByteArray>,
|
||||
val remoteAddress: NetworkAddress,
|
||||
onclose: ()->Unit = {}
|
||||
) : ProxyDevice(inputChannel, outputChannel, onclose)
|
||||
val flush: suspend ()->Unit = {},
|
||||
doClose: suspend ()->Unit = {}
|
||||
) : ProxyDevice(inputChannel, outputChannel, doClose)
|
@ -8,7 +8,7 @@ import net.sergeych.kiloparsec.Transport
|
||||
open class ProxyDevice(
|
||||
inputChannel: Channel<UByteArray?>,
|
||||
outputChannel: Channel<UByteArray>,
|
||||
private val onClose: ()->Unit = {}): Transport.Device {
|
||||
private val onClose: suspend ()->Unit = {}): Transport.Device {
|
||||
|
||||
override val input: ReceiveChannel<UByteArray?> = inputChannel
|
||||
override val output: SendChannel<UByteArray> = outputChannel
|
||||
|
20
src/commonMain/kotlin/net/sergeych/tools/flow_tools.kt
Normal file
20
src/commonMain/kotlin/net/sergeych/tools/flow_tools.kt
Normal file
@ -0,0 +1,20 @@
|
||||
package net.sergeych.tools
|
||||
|
||||
import kotlinx.coroutines.cancel
|
||||
import kotlinx.coroutines.coroutineScope
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
/**
|
||||
* suspend until the flow produces the value to which the
|
||||
* predicate returns true
|
||||
*/
|
||||
suspend fun <T>Flow<T>.waitFor(predicate: (T)->Boolean) {
|
||||
coroutineScope {
|
||||
launch {
|
||||
collect {
|
||||
if( predicate(it) ) cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -3,13 +3,14 @@ package net.sergeych.kiloparsec.adapter
|
||||
import kotlinx.coroutines.*
|
||||
import kotlinx.coroutines.channels.Channel
|
||||
import kotlinx.coroutines.channels.ClosedReceiveChannelException
|
||||
import net.sergeych.bintools.toDump
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import net.sergeych.crypto.encodeVarUnsigned
|
||||
import net.sergeych.crypto.readVarUnsigned
|
||||
import net.sergeych.kiloparsec.Transport
|
||||
import net.sergeych.mp_logger.LogTag
|
||||
import net.sergeych.mp_logger.warning
|
||||
import net.sergeych.mp_tools.globalLaunch
|
||||
import net.sergeych.tools.waitFor
|
||||
import java.net.InetSocketAddress
|
||||
import java.net.StandardSocketOptions.TCP_NODELAY
|
||||
import java.nio.ByteBuffer
|
||||
@ -22,8 +23,7 @@ private val log = LogTag("ASTD")
|
||||
/**
|
||||
* Prepend block with its size, varint-encoded
|
||||
*/
|
||||
private fun encode(block: UByteArray): ByteArray
|
||||
= (encodeVarUnsigned(block.size.toUInt()) + block).toByteArray()
|
||||
private fun encode(block: UByteArray): ByteArray = (encodeVarUnsigned(block.size.toUInt()) + block).toByteArray()
|
||||
|
||||
/**
|
||||
* Convert asynchronous socket to a [Transport.Device] using non-blocking nio,
|
||||
@ -35,14 +35,19 @@ suspend fun asyncSocketToDevice(socket: AsynchronousSocketChannel): InetTranspor
|
||||
val deferredDevice = CompletableDeferred<InetTransportDevice>()
|
||||
globalLaunch {
|
||||
coroutineScope {
|
||||
val sendQueueEmpty = MutableStateFlow(true)
|
||||
val receiving = MutableStateFlow(false)
|
||||
fun stop() {
|
||||
cancel()
|
||||
}
|
||||
// We're in block mode, every block we send worth immediate sending, we do not
|
||||
// send partial blocks, so:
|
||||
socket.setOption(TCP_NODELAY, true)
|
||||
|
||||
// socket input is to be parsed for blocks, so we receive bytes
|
||||
// and decode them to blocks
|
||||
val input = Channel<UByte>(1024)
|
||||
// copy from socket to input:
|
||||
// copy incoming data from the socket to input channel:
|
||||
launch {
|
||||
val data = ByteArray(1024)
|
||||
val inb = ByteBuffer.wrap(data)
|
||||
@ -53,32 +58,41 @@ suspend fun asyncSocketToDevice(socket: AsynchronousSocketChannel): InetTranspor
|
||||
}
|
||||
if (size < 0) stop()
|
||||
else {
|
||||
println("recvd:\n${data.sliceArray(0..<size).toDump()}\n------------------")
|
||||
// println("recvd:\n${data.sliceArray(0..<size).toDump()}\n------------------")
|
||||
for (i in 0..<size) input.send(data[i].toUByte())
|
||||
}
|
||||
}
|
||||
}
|
||||
// output is blocks, so we sent blocks:
|
||||
|
||||
// output is blocks, so we sent transformed, framed blocks:
|
||||
val outputBlocks = Channel<UByteArray>()
|
||||
// copy from output to socket:
|
||||
launch {
|
||||
try {
|
||||
while (isActive) {
|
||||
// wait for the first block to send
|
||||
sendQueueEmpty.value = outputBlocks.isEmpty
|
||||
var data = encode(outputBlocks.receive())
|
||||
|
||||
// now we're sending, so queue state is sending:
|
||||
sendQueueEmpty.value = false
|
||||
|
||||
// if there are more, take them all (NO_DELAY optimization)
|
||||
while (!outputBlocks.isEmpty)
|
||||
data += encode(outputBlocks.receive())
|
||||
// now send the aggregate:
|
||||
|
||||
// now send it all together:
|
||||
val outBuff = ByteBuffer.wrap(data)
|
||||
val cnt = suspendCoroutine { continuation ->
|
||||
socket.write(outBuff, continuation, IntCompletionHandler)
|
||||
}
|
||||
// be sure it was all sent
|
||||
if (outBuff.position() != data.size || cnt != data.size) {
|
||||
throw RuntimeException("PArtial write!")
|
||||
throw RuntimeException("unexpected partial write")
|
||||
}
|
||||
}
|
||||
// in the case of just breaking out of the loop:
|
||||
sendQueueEmpty.value = true
|
||||
} catch (_: ClosedReceiveChannelException) {
|
||||
stop()
|
||||
}
|
||||
@ -89,7 +103,9 @@ suspend fun asyncSocketToDevice(socket: AsynchronousSocketChannel): InetTranspor
|
||||
launch {
|
||||
try {
|
||||
while (isActive) {
|
||||
receiving.value = !input.isEmpty
|
||||
val size = readVarUnsigned(input)
|
||||
receiving.value = true
|
||||
if (size == 0u) log.warning { "zero size block is ignored!" }
|
||||
else {
|
||||
val block = UByteArray(size.toInt())
|
||||
@ -105,11 +121,33 @@ suspend fun asyncSocketToDevice(socket: AsynchronousSocketChannel): InetTranspor
|
||||
inputBlocks.send(null)
|
||||
stop()
|
||||
}
|
||||
receiving.value = false
|
||||
}
|
||||
// SocketAddress.
|
||||
|
||||
// wait until send queue is empty
|
||||
suspend fun flush() {
|
||||
yield()
|
||||
// do not slow down with collect if it is ok by now:
|
||||
if (!sendQueueEmpty.value || !outputBlocks.isEmpty)
|
||||
// wait until all output is sent
|
||||
sendQueueEmpty.waitFor { it && outputBlocks.isEmpty }
|
||||
}
|
||||
|
||||
val addr = socket.remoteAddress as InetSocketAddress
|
||||
deferredDevice.complete(
|
||||
InetTransportDevice(inputBlocks, outputBlocks, JvmNetworkAddress(addr.address,addr.port)) { stop() }
|
||||
InetTransportDevice(inputBlocks, outputBlocks, JvmNetworkAddress(addr.address, addr.port),
|
||||
{ flush() }
|
||||
) {
|
||||
yield()
|
||||
// wait until all received data are parsed, but not too long
|
||||
withTimeoutOrNull( 1000 ) {
|
||||
receiving.waitFor { !it }
|
||||
}
|
||||
// graceful close: flush output
|
||||
flush()
|
||||
// then stop it
|
||||
stop()
|
||||
}
|
||||
)
|
||||
}
|
||||
globalLaunch { socket.close() }
|
||||
|
@ -9,8 +9,10 @@ import net.sergeych.kiloparsec.adapter.acceptTcpDevice
|
||||
import net.sergeych.kiloparsec.adapter.connectTcpDevice
|
||||
import net.sergeych.kiloparsec.adapter.toNetworkAddress
|
||||
import net.sergeych.mp_logger.Log
|
||||
import net.sergeych.tools.ProtectedOp
|
||||
import org.junit.jupiter.api.Assertions.assertEquals
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertContains
|
||||
|
||||
class NetworkTest {
|
||||
|
||||
@ -35,6 +37,8 @@ class NetworkTest {
|
||||
|
||||
coroutineScope {
|
||||
val serverFlow = acceptTcpDevice(17171)
|
||||
val op = ProtectedOp()
|
||||
var pills = setOf<String>()
|
||||
val j = launch {
|
||||
serverFlow.collect { device ->
|
||||
launch {
|
||||
@ -42,12 +46,13 @@ class NetworkTest {
|
||||
device.output.send("Great".encodeToUByteArray())
|
||||
while (true) {
|
||||
val x = device.input.receive()?.decodeFromUByteArray() ?: break
|
||||
if (x == "Goodbye") break
|
||||
if (x == "die") {
|
||||
println("collector get poisoned pill")
|
||||
cancel()
|
||||
break
|
||||
if (x.startsWith("die")) {
|
||||
op {
|
||||
pills += x
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
else
|
||||
println("ignoring unexpected input: $x")
|
||||
}
|
||||
}
|
||||
@ -59,14 +64,20 @@ class NetworkTest {
|
||||
assertEquals("Hello, world!", s.input.receive()!!.decodeFromUByteArray())
|
||||
assertEquals("Great", s.input.receive()!!.decodeFromUByteArray())
|
||||
s.output.send("Goodbye".encodeToUByteArray())
|
||||
s.output.send("die1".encodeToUByteArray())
|
||||
s.close()
|
||||
}
|
||||
val s1 = connectTcpDevice("127.0.1.1:17171".toNetworkAddress())
|
||||
assertEquals("Hello, world!", s1.input.receive()!!.decodeFromUByteArray())
|
||||
assertEquals("Great", s1.input.receive()!!.decodeFromUByteArray())
|
||||
s1.output.send("die".encodeToUByteArray())
|
||||
delay(200)
|
||||
s1.output.send("die2".encodeToUByteArray())
|
||||
s1.close()
|
||||
|
||||
// check that channels were flushed prior to closed:
|
||||
assertContains(pills, "die1")
|
||||
assertContains(pills, "die2")
|
||||
|
||||
// Check that server jobs are closed
|
||||
j.cancelAndJoin()
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user