141 lines
5.6 KiB
Kotlin

package net.sergeych.bipack
import kotlinx.datetime.Instant
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.encoding.AbstractDecoder
import kotlinx.serialization.encoding.CompositeDecoder
import kotlinx.serialization.modules.EmptySerializersModule
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.serializer
import net.sergeych.bintools.*
/**
 * Decoder for the BiPack binary format.
 *
 * All reads go through [input], a [DataSource], so truncated data surfaces as a
 * [DataSource.EndOfData] exception. Structures annotated with [Framed] or [CrcProtected] can
 * additionally throw [InvalidFrameException] and its derivatives ([InvalidFrameHeaderException]
 * when the frame header does not match, [InvalidFrameCRCException] when the checksum does not match).
 *
 * @param input source of encoded bytes.
 * @param elementsCount number of elements [decodeElementIndex] yields before reporting
 *   [CompositeDecoder.DECODE_DONE]; also returned by [decodeCollectionSize] when [hasFixedSize] is set.
 * @param isCollection when true, elements are decoded sequentially (see [decodeSequentially]).
 * @param hasFixedSize when true, the collection size is taken from [elementsCount] instead of being
 *   read from [input].
 */
@Suppress("UNCHECKED_CAST")
class BipackDecoder(
    val input: DataSource,
    var elementsCount: Int = 0,
    val isCollection: Boolean = false,
    val hasFixedSize: Boolean = false,
) : AbstractDecoder() {

    // Index of the next structure element to report from decodeElementIndex.
    private var elementIndex = 0

    // Per-element flags, derived from the annotations of the element currently being decoded
    // (set in decodeElementIndex, consumed by the decode* / beginStructure calls for that element).
    private var nextIsUnsigned = false
    private var fixedSize = -1
    private var fixedNumber = false

    override val serializersModule: SerializersModule = EmptySerializersModule()

    /** A boolean is one byte; any non-zero value is `true`. */
    override fun decodeBoolean(): Boolean = input.readByte().toInt() != 0

    override fun decodeByte(): Byte = input.readByte()

    /** Fixed 16-bit if the element is [Fixed], otherwise a variable-length (optionally [Unsigned]) number. */
    override fun decodeShort(): Short = when {
        fixedNumber -> input.readI16()
        nextIsUnsigned -> input.readNumber<UInt>().toShort()
        else -> input.readNumber()
    }

    /** Fixed 32-bit if the element is [Fixed], otherwise a variable-length (optionally [Unsigned]) number. */
    override fun decodeInt(): Int = when {
        fixedNumber -> input.readI32()
        nextIsUnsigned -> input.readNumber<UInt>().toInt()
        else -> input.readNumber()
    }

    /** Fixed 64-bit if the element is [Fixed], otherwise a variable-length (optionally [Unsigned]) number. */
    override fun decodeLong(): Long = when {
        fixedNumber -> input.readI64()
        nextIsUnsigned -> input.readNumber<ULong>().toLong()
        else -> input.readNumber()
    }

    override fun decodeFloat(): Float = input.readFloat()

    override fun decodeDouble(): Double = input.readDouble()

    /** A char is stored as its code, encoded as an unsigned variable-length number. */
    override fun decodeChar(): Char = Char(input.readNumber<UInt>().toInt())

    /** Reads a length-prefixed byte array: unsigned variable-length size, then that many bytes. */
    fun readBytes(): ByteArray {
        val length = input.readNumber<UInt>()
        return input.readBytes(length.toInt())
    }

    /** A string is a length-prefixed byte array decoded with [ByteArray.decodeToString] (UTF-8). */
    override fun decodeString(): String = readBytes().decodeToString()

    /** An enum is stored as its ordinal, encoded as an unsigned variable-length number. */
    override fun decodeEnum(enumDescriptor: SerialDescriptor): Int = input.readNumber<UInt>().toInt()

    /**
     * Returns the next element index, or [CompositeDecoder.DECODE_DONE] once [elementsCount]
     * elements have been reported. Before returning, captures the element's [Unsigned], [FixedSize]
     * and [Fixed] annotations into the per-element flags.
     */
    override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
        if (elementIndex >= elementsCount) return CompositeDecoder.DECODE_DONE
        // Reset every per-element flag: annotations of a previous element must not leak into this one.
        nextIsUnsigned = false
        fixedSize = -1
        fixedNumber = false
        for (a in descriptor.getElementAnnotations(elementIndex)) {
            when (a) {
                is Unsigned -> nextIsUnsigned = true
                is FixedSize -> fixedSize = a.size
                is Fixed -> fixedNumber = true
            }
        }
        return elementIndex++
    }

    /** [Instant] is special-cased: it is decoded as epoch milliseconds via [decodeLong]. */
    override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T =
        if (deserializer == Instant.serializer())
            Instant.fromEpochMilliseconds(decodeLong()) as T
        else
            super.decodeSerializableValue(deserializer)

    override fun decodeSequentially(): Boolean = isCollection

    /**
     * Starts decoding a nested structure, list or map, and returns a child decoder for it.
     *
     * - [CrcProtected] structures are read through a [CRC32Source], so everything below (including
     *   the element count and frame header) is covered by the checksum verified in [endStructure].
     * - [Extendable] structures read their element count from the stream.
     * - [Framed] structures must start with the CRC32 of their serial name.
     *
     * @throws InvalidFrameHeaderException if a [Framed] header does not match.
     */
    override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
        val childIsCollection =
            descriptor.kind == StructureKind.LIST || descriptor.kind == StructureKind.MAP
        val source = if (descriptor.annotations.any { it is CrcProtected })
            CRC32Source(input)
        else
            input
        // Note: we must read from 'source' explicitly, as it might be the CRC-calculating one,
        // and the fields below are CRC protected too:
        var count = if (fixedSize >= 0) fixedSize else descriptor.elementsCount
        for (a in descriptor.annotations) {
            if (a is Extendable)
                count = source.readVarUInt().toInt()
            else if (a is Framed) {
                val code = CRC.crc32(descriptor.serialName.encodeToByteArray())
                // If we fail to read the CRC, it is an IO error, so DataSource.EndOfData will be
                // thrown here, which is better than an invalid frame exception:
                val actual = source.readU32()
                if (code != actual)
                    throw InvalidFrameHeaderException()
            }
        }
        return BipackDecoder(source, count, childIsCollection, fixedSize >= 0)
    }

    /** Size of the current collection: [elementsCount] when fixed, otherwise read as an unsigned number. */
    override fun decodeCollectionSize(descriptor: SerialDescriptor): Int =
        if (hasFixedSize)
            elementsCount
        else
            input.readNumber<UInt>().toInt()

    /**
     * Finishes a structure; for [CrcProtected] structures compares the checksum accumulated so far
     * with the 32-bit value that follows in the stream.
     *
     * @throws InvalidFrameCRCException on checksum mismatch.
     */
    override fun endStructure(descriptor: SerialDescriptor) {
        if (input is CRC32Source && descriptor.annotations.any { it is CrcProtected }) {
            // Capture the checksum before reading the stored value, which itself passes through the CRC source.
            val actual = input.crc
            val expected = input.readU32()
            if (actual != expected)
                throw InvalidFrameCRCException()
        }
        super.endStructure(descriptor)
    }

    /** The not-null mark is a single boolean byte. */
    override fun decodeNotNullMark(): Boolean = decodeBoolean()

    /** Nulls carry no payload beyond the mark read by [decodeNotNullMark]. */
    @ExperimentalSerializationApi
    override fun decodeNull(): Nothing? = null

    companion object {
        /** Decodes a value from [source] using an explicit [deserializer]. */
        fun <T> decode(source: DataSource, deserializer: DeserializationStrategy<T>): T =
            BipackDecoder(source).decodeSerializableValue(deserializer)

        /** Decodes a value of type [T] from [source] using its default serializer. */
        @Suppress("unused")
        inline fun <reified T> decode(source: DataSource): T = decode(source, serializer())

        /** Decodes a value of type [T] from a BiPack-encoded byte array. */
        inline fun <reified T> decode(source: ByteArray): T =
            decode(source.toDataSource(), serializer())
    }
}
/**
 * Decodes this BiPack-encoded byte array into a value of type [T], using [T]'s default serializer.
 * Convenience shorthand for `BipackDecoder.decode<T>(this)`.
 */
inline fun <reified T> ByteArray.decodeFromBipack(): T {
    return BipackDecoder.decode<T>(this)
}