package net.sergeych.bipack

import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.encoding.AbstractDecoder
import kotlinx.serialization.encoding.CompositeDecoder
import kotlinx.serialization.modules.EmptySerializersModule
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.serializer
import net.sergeych.bintools.*
import kotlin.time.Instant

/**
 * Decoder for the BiPack binary format.
 *
 * It reads from a [DataSource], so any read can throw [DataSource.EndOfData] when the input is
 * exhausted. Structures using frames can throw [InvalidFrameException] and its derivatives:
 * a `@Framed` structure whose header does not match throws [InvalidFrameHeaderException], and a
 * `@CrcProtected` structure whose checksum does not match throws [InvalidFrameCRCException].
 *
 * @param input the source to read encoded data from.
 * @param elementsCount how many elements [decodeElementIndex] yields before returning
 *   [CompositeDecoder.DECODE_DONE]; for fixed-size collections it is also the collection size.
 * @param isCollection when true, [decodeSequentially] returns true, so elements are read in order
 *   without calling [decodeElementIndex].
 * @param hasFixedSize when true, [decodeCollectionSize] returns [elementsCount] instead of reading
 *   the size from [input].
 */
@Suppress("UNCHECKED_CAST")
class BipackDecoder(
    val input: DataSource,
    var elementsCount: Int = 0,
    val isCollection: Boolean = false,
    val hasFixedSize: Boolean = false,
) : AbstractDecoder() {

    private var elementIndex = 0

    // Encoding hints for the next value. They are reset and re-derived from the element's
    // annotations in decodeElementIndex. decodeInline also sets nextIsUnsigned for
    // unsigned inline types.
    private var nextIsUnsigned = false
    private var nextIsVarint = false

    // NOTE(review): unlike the two flags above, these are only ever set, never reset, in
    // decodeElementIndex. A @FixedSize/@Fixed annotation therefore also applies to every
    // later element decoded by this instance. The encoder must behave identically for
    // data to round-trip; confirm against BipackEncoder before changing this.
    private var fixedSize = -1
    private var fixedNumber = false

    override val serializersModule: SerializersModule = EmptySerializersModule()

    /** A single byte; any non-zero value is `true`. */
    override fun decodeBoolean(): Boolean = input.readByte().toInt() != 0

    override fun decodeByte(): Byte = input.readByte()

    /** Fixed 16-bit if `@Fixed`; varint if `@Varint`; otherwise a variable-length number. */
    override fun decodeShort(): Short = when {
        fixedNumber -> input.readI16()
        nextIsVarint -> if (nextIsUnsigned) input.readVarUInt().toShort() else input.readVarInt().toShort()
        nextIsUnsigned -> input.readNumber<UInt>().toShort()
        else -> input.readNumber()
    }

    /** Fixed 32-bit if `@Fixed`; varint if `@Varint`; otherwise a variable-length number. */
    override fun decodeInt(): Int = when {
        fixedNumber -> input.readI32()
        nextIsVarint -> if (nextIsUnsigned) input.readVarUInt().toInt() else input.readVarInt()
        nextIsUnsigned -> input.readNumber<UInt>().toInt()
        else -> input.readNumber()
    }

    /** Fixed 64-bit if `@Fixed`; varint if `@Varint`; otherwise a variable-length number. */
    override fun decodeLong(): Long = when {
        fixedNumber -> input.readI64()
        // Fully qualified: the bare name `Varint` resolves to this package's annotation.
        nextIsVarint ->
            if (nextIsUnsigned) net.sergeych.bintools.Varint.decodeUnsigned(input).toLong()
            else net.sergeych.bintools.Varint.decodeSigned(input)
        nextIsUnsigned -> input.readNumber<ULong>().toLong()
        else -> input.readNumber()
    }

    override fun decodeFloat(): Float = input.readFloat()

    override fun decodeDouble(): Double = input.readDouble()

    /** A char is stored as its code, written as an unsigned number. */
    override fun decodeChar(): Char = Char(input.readNumber<UInt>().toInt())

    /** Reads a length-prefixed byte array: an unsigned-number length, then that many bytes. */
    fun readBytes(): ByteArray {
        val length = input.readNumber<UInt>()
        return input.readBytes(length.toInt())
    }

    /** A string is a length-prefixed byte array decoded as UTF-8. */
    override fun decodeString(): String = readBytes().decodeToString()

    /** An enum is stored as the index of its entry in [enumDescriptor], as an unsigned number. */
    override fun decodeEnum(enumDescriptor: SerialDescriptor): Int = input.readNumber<UInt>().toInt()

    /**
     * Returns the next element index, or [CompositeDecoder.DECODE_DONE] once [elementsCount]
     * elements were produced.
     *
     * The element's `@Unsigned`, `@Varint`, `@FixedSize` and `@Fixed` annotations are applied
     * to the decoding flags before returning.
     */
    override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
        if (elementIndex >= elementsCount) return CompositeDecoder.DECODE_DONE
        nextIsUnsigned = false
        nextIsVarint = false
        for (a in descriptor.getElementAnnotations(elementIndex)) {
            when (a) {
                is Unsigned -> nextIsUnsigned = true
                is Varint -> nextIsVarint = true
                is FixedSize -> fixedSize = a.size
                is Fixed -> fixedNumber = true
            }
        }
        return elementIndex++
    }

    /** Inline unsigned types (UInt, ULong, UShort, UByte) make the next number unsigned. */
    override fun decodeInline(descriptor: SerialDescriptor): BipackDecoder {
        if (descriptor.isUnsignedInlinePrimitive()) nextIsUnsigned = true
        return this
    }

    /**
     * [Instant] is decoded from epoch milliseconds read by [decodeLong].
     * Every other type uses the default strategy.
     */
    override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T =
        if (deserializer == serializer<Instant>())
            Instant.fromEpochMilliseconds(decodeLong()) as T
        else
            super.decodeSerializableValue(deserializer)

    override fun decodeSequentially(): Boolean = isCollection

    /**
     * Creates a child decoder for a nested structure or collection.
     *
     * The child reads through a CRC-computing source when [descriptor] is `@CrcProtected`. Its
     * element count comes from the pending `@FixedSize` if any, or from the descriptor. An
     * `@Extendable` structure instead reads its count from the input. A `@Framed` structure
     * first reads and checks a CRC-32 of its serial name.
     *
     * @throws InvalidFrameHeaderException if a `@Framed` header does not match.
     */
    override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
        val childIsCollection = descriptor.kind == StructureKind.LIST || descriptor.kind == StructureKind.MAP
        val source = if (descriptor.annotations.any { it is CrcProtected }) CRC32Source(input) else input
        // Note: read from `source` explicitly, as it might be the CRC-calculating one, and the
        // header fields below are CRC-protected too:
        var count = if (fixedSize >= 0) fixedSize else descriptor.elementsCount
        for (a in descriptor.annotations) {
            when (a) {
                is Extendable -> count = source.readVarUInt().toInt()
                is Framed -> {
                    val code = CRC.crc32(descriptor.serialName.encodeToByteArray())
                    // If reading the CRC fails, that is an I/O error, so DataSource.EndOfData
                    // is thrown here, which is more accurate than an invalid-frame exception:
                    val actual = source.readU32()
                    if (code != actual) throw InvalidFrameHeaderException()
                }
            }
        }
        return BipackDecoder(source, count, childIsCollection, fixedSize >= 0)
    }

    /** Uses [elementsCount] for fixed-size collections; otherwise reads an unsigned size. */
    override fun decodeCollectionSize(descriptor: SerialDescriptor): Int =
        if (hasFixedSize) elementsCount else input.readNumber<UInt>().toInt()

    /**
     * For a `@CrcProtected` structure read through a [CRC32Source], reads the trailing CRC-32 and
     * compares it with the checksum accumulated so far.
     *
     * @throws InvalidFrameCRCException if the checksums differ.
     */
    override fun endStructure(descriptor: SerialDescriptor) {
        if (input is CRC32Source && descriptor.annotations.any { it is CrcProtected }) {
            // Capture the checksum before reading the stored value, which also passes through the CRC source.
            val actual = input.crc
            val expected = input.readU32()
            if (actual != expected) throw InvalidFrameCRCException()
        }
        super.endStructure(descriptor)
    }

    /**
     * Reads the not-null marker as a boolean.
     *
     * If the input ends exactly here, the value is treated as null ([DataSource.EndOfData] is
     * swallowed deliberately) rather than failing.
     */
    override fun decodeNotNullMark(): Boolean =
        try {
            decodeBoolean()
        } catch (_: DataSource.EndOfData) {
            false
        }

    @ExperimentalSerializationApi
    override fun decodeNull(): Nothing? = null

    private fun SerialDescriptor.isUnsignedInlinePrimitive(): Boolean =
        isInline && serialName in UNSIGNED_INLINE_SERIAL_NAMES

    companion object {
        /** Serial names of inline unsigned types whose values are decoded as unsigned numbers. */
        private val UNSIGNED_INLINE_SERIAL_NAMES =
            setOf("kotlin.UInt", "kotlin.ULong", "kotlin.UShort", "kotlin.UByte")

        /** Decodes one value from [source] using [deserializer]. */
        fun <T> decode(source: DataSource, deserializer: DeserializationStrategy<T>): T =
            BipackDecoder(source).decodeSerializableValue(deserializer)

        /** Decodes one value of type [T] from [source] using its default serializer. */
        @Suppress("unused")
        inline fun <reified T> decode(source: DataSource): T = decode(source, serializer<T>())

        /** Decodes one value of type [T] from [source] using its default serializer. */
        inline fun <reified T> decode(source: ByteArray): T = decode(source.toDataSource(), serializer<T>())

        /** Decodes one value from [source] using [serializer]. */
        fun <T> decode(serializer: KSerializer<T>, source: ByteArray): T =
            decode(source.toDataSource(), serializer)

        /** Decodes one value of type [T] from [source] using its default serializer. */
        inline fun <reified T> decode(source: UByteArray): T = decode(source.toDataSource(), serializer<T>())

        /** Decodes one value from [source] using [serializer]. */
        fun <T> decode(serializer: KSerializer<T>, source: UByteArray): T =
            decode(source.toDataSource(), serializer)
    }
}

/** Decodes a BiPack-encoded value of type [T] from this byte array. */
inline fun <reified T> ByteArray.decodeFromBipack(): T = BipackDecoder.decode<T>(this)

/** Decodes a BiPack-encoded value of type [T] from this unsigned byte array. */
@Suppress("unused")
inline fun <reified T> UByteArray.decodeFromBipack(): T = BipackDecoder.decode<T>(this)