diff --git a/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/symmetric/Salsa20.kt b/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/symmetric/Salsa20.kt index 82c488b..dd2d0e2 100644 --- a/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/symmetric/Salsa20.kt +++ b/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/symmetric/Salsa20.kt @@ -1,6 +1,7 @@ package com.ionspin.kotlin.crypto.symmetric -import com.ionspin.kotlin.crypto.util.rotateLeft +import com.ionspin.kotlin.crypto.keyderivation.argon2.xorWithBlock +import com.ionspin.kotlin.crypto.util.* /** * Created by Ugljesa Jovanovic @@ -74,11 +75,8 @@ class Salsa20 { output[outputPosition + 3] = ((input shr 24) and 0xFFU).toUByte() } - fun hash(input: UByteArray): UByteArray { - val state = UIntArray(16) { - littleEndian(input, (it * 4) + 0, (it * 4) + 1, (it * 4) + 2, (it * 4) + 3) - } - val initialState = state.copyOf() + fun hash(initialState: UIntArray): UByteArray { + val state = initialState.copyOf() for (i in 0 until 10) { doubleRound(state) } @@ -94,17 +92,58 @@ class Salsa20 { val sigma2_32 = ubyteArrayOf(50U, 45U, 98U, 121U) val sigma3_32 = ubyteArrayOf(116U, 101U, 32U, 107U) - val sigma0_16 = ubyteArrayOf(101U, 120U, 112U, 97U) - val sigma1_16 = ubyteArrayOf(110U, 100U, 32U, 49U) - val sigma2_16 = ubyteArrayOf(54U, 45U, 98U, 121U) - val sigma3_16 = ubyteArrayOf(116U, 101U, 32U, 107U) + val tau0_16 = ubyteArrayOf(101U, 120U, 112U, 97U) + val tau1_16 = ubyteArrayOf(110U, 100U, 32U, 49U) + val tau2_16 = ubyteArrayOf(54U, 45U, 98U, 121U) + val tau3_16 = ubyteArrayOf(116U, 101U, 32U, 107U) fun expansion16(k: UByteArray, n: UByteArray) : UByteArray { - return hash(sigma0_16 + k + sigma1_16 + n + sigma2_16 + k + sigma3_16) + return hash((tau0_16 + k + tau1_16 + n + tau2_16 + k + tau3_16).fromLittleEndianToUInt()) } - fun expansion32(k:UByteArray, n: UByteArray) : UByteArray { - return hash(sigma0_32 + k.slice(0 until 16) + sigma1_32 + n + sigma2_32 + k.slice(16 until 32) + sigma3_32) + fun expansion32(key :UByteArray, nonce : UByteArray) : UByteArray { + return hash((sigma0_32 + key.slice(0 until 16) + sigma1_32 + nonce + sigma2_32 + key.slice(16 until 32) + sigma3_32).fromLittleEndianToUInt()) + } + + fun encrypt(key : UByteArray, nonce: UByteArray, message: UByteArray) : UByteArray { + val ciphertext = UByteArray(message.size) + val state = UIntArray(16) { + when (it) { + 0 -> sigma0_32.fromLittleEndianArrayToUInt() + 1 -> key.fromLittleEndianArrayToUIntWithPosition(0) + 2 -> key.fromLittleEndianArrayToUIntWithPosition(4) + 3 -> key.fromLittleEndianArrayToUIntWithPosition(8) + 4 -> key.fromLittleEndianArrayToUIntWithPosition(12) + 5 -> sigma1_32.fromLittleEndianArrayToUInt() + 6 -> nonce.fromLittleEndianArrayToUIntWithPosition(0) + 7 -> nonce.fromLittleEndianArrayToUIntWithPosition(4) + 8 -> 0U + 9 -> 0U + 10 -> sigma2_32.fromLittleEndianArrayToUInt() + 11 -> key.fromLittleEndianArrayToUIntWithPosition(16) + 12 -> key.fromLittleEndianArrayToUIntWithPosition(20) + 13 -> key.fromLittleEndianArrayToUIntWithPosition(24) + 14 -> key.fromLittleEndianArrayToUIntWithPosition(28) + 15 -> sigma3_32.fromLittleEndianArrayToUInt() + else -> 0U + } + } + val remainder = message.size % 64 + for (i in 0 until message.size - 64 step 64) { + hash(state).xorWithPositionsAndInsertIntoArray(0, 64, message, i, ciphertext, i) + state[8] += 1U + if (state[8] == 0U) { + state[9] += 1U + } + } + for ( i in message.size - (64 - remainder) until message.size step 64) { + hash(state).xorWithPositionsAndInsertIntoArray(0, (64 - remainder), message, i, ciphertext, i) + state[8] += 1U + if (state[8] == 0U) { + state[9] += 1U + } + } + return ciphertext } } diff --git a/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/util/Util.kt b/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/util/Util.kt index a654ea5..d2ea19c 100644 --- a/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/util/Util.kt +++ b/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/util/Util.kt @@ -18,7 +18,8 @@ package com.ionspin.kotlin.crypto.util -import com.ionspin.kotlin.crypto.keyderivation.argon2.ArgonBlockPointer +import com.ionspin.kotlin.crypto.keyderivation.argon2.Argon2Utils +import com.ionspin.kotlin.crypto.keyderivation.argon2.xorWithBlock /** * Created by Ugljesa Jovanovic @@ -98,6 +99,21 @@ infix fun UByteArray.xor(other : UByteArray) : UByteArray { return UByteArray(this.size) { this[it] xor other[it] } } +fun UByteArray.xorWithPositions(start: Int, end: Int, other : UByteArray, otherStart: Int) : UByteArray { + return UByteArray(end - start) { this[start + it] xor other[otherStart + it] } +} + +fun UByteArray.xorWithPositionsAndInsertIntoArray( + start: Int, end: Int, + other : UByteArray, otherStart: Int, + targetArray: UByteArray, targetStart : Int) { + for (i in start until end) { + if (targetStart + i == 131071) { + println("stop") + } + targetArray[targetStart + i] = this[start + i] xor other[otherStart + i] + } +} fun String.hexStringToTypedUByteArray() : Array { return this.chunked(2).map { it.toUByte(16) }.toTypedArray() @@ -241,7 +257,7 @@ fun UByteArray.fromLittleEndianArrayToUInt() : UInt { return uint } -fun UByteArray.fromLittleEndianArrayToUintWithPosition(position: Int) : UInt{ +fun UByteArray.fromLittleEndianArrayToUIntWithPosition(position: Int) : UInt{ var uint = 0U for (i in 0 until 4) { uint = uint or (this[position + i].toUInt() shl (i * 8)) @@ -249,7 +265,15 @@ fun UByteArray.fromLittleEndianArrayToUintWithPosition(position: Int) : UInt{ return uint } -fun UByteArray.fromBigEndianArrayToUintWithPosition(position: Int) : UInt{ +fun UByteArray.fromBigEndianArrayToUInt() : UInt{ + var uint = 0U + for (i in 0 until 4) { + uint = uint shl 8 or (this[i].toUInt()) + } + return uint +} + +fun UByteArray.fromBigEndianArrayToUIntWithPosition(position: Int) : UInt{ var uint = 0U for (i in 0 until 4) { uint = uint shl 8 or (this[position + i].toUInt()) @@ -269,6 +293,25 @@ fun UByteArray.insertUIntAtPositionAsBigEndian(position: Int, value: UInt) { } } +fun UByteArray.fromLittleEndianToUInt() : UIntArray { + if (size % 4 != 0) { + throw RuntimeException("Invalid size (not divisible by 4)") + } + return UIntArray(size / 4) { + fromLittleEndianArrayToUIntWithPosition(it * 4) + } +} + + +fun UByteArray.fromBigEndianToUInt() : UIntArray { + if (size % 4 != 0) { + throw RuntimeException("Invalid size (not divisible by 4)") + } + return UIntArray(size / 4) { + fromBigEndianArrayToUIntWithPosition(it * 4) + } +} + fun Array.fromBigEndianArrayToUInt() : UInt { if (this.size > 4) { throw RuntimeException("ore than 8 bytes in input, potential overflow") diff --git a/multiplatform-crypto/src/commonTest/kotlin/com/ionspin/kotlin/crypto/symmetric/Salsa20Test.kt b/multiplatform-crypto/src/commonTest/kotlin/com/ionspin/kotlin/crypto/symmetric/Salsa20Test.kt index f16b489..ff284e0 100644 --- a/multiplatform-crypto/src/commonTest/kotlin/com/ionspin/kotlin/crypto/symmetric/Salsa20Test.kt +++ b/multiplatform-crypto/src/commonTest/kotlin/com/ionspin/kotlin/crypto/symmetric/Salsa20Test.kt @@ -1,6 +1,9 @@ package com.ionspin.kotlin.crypto.symmetric +import com.ionspin.kotlin.crypto.util.fromLittleEndianToUInt +import com.ionspin.kotlin.crypto.util.hexStringToUByteArray import com.ionspin.kotlin.crypto.util.rotateLeft +import com.ionspin.kotlin.crypto.util.toHexString import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertTrue @@ -266,7 +269,7 @@ class Salsa20Test { 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U ) - val result = Salsa20.hash(input) + val result = Salsa20.hash(input.fromLittleEndianToUInt()) input.contentEquals(expected) } @@ -283,7 +286,7 @@ class Salsa20Test { 118U, 40U, 152U, 157U, 180U, 57U, 27U, 94U, 107U, 42U, 236U, 35U, 27U, 111U, 114U, 114U, 219U, 236U, 232U, 135U, 111U, 155U, 110U, 18U, 24U, 232U, 95U, 158U, 179U, 19U, 48U, 202U ) - val result = Salsa20.hash(input) + val result = Salsa20.hash(input.fromLittleEndianToUInt()) result.contentEquals(expected) } @@ -301,7 +304,7 @@ class Salsa20Test { 122U, 127U, 195U, 185U, 185U, 204U, 188U, 90U, 245U, 9U, 183U, 248U, 226U, 85U, 245U, 104U ) val result = (0 until 1_000_000).fold(input) { acc, _ -> - Salsa20.hash(acc) + Salsa20.hash(acc.fromLittleEndianToUInt()) } result.contentEquals(expected) } @@ -311,29 +314,80 @@ class Salsa20Test { @Test fun testExpansion() { val k0 = ubyteArrayOf(1U, 2U, 3U, 4U, 5U, 6U, 7U, 8U, 9U, 10U, 11U, 12U, 13U, 14U, 15U, 16U) - val k1 = ubyteArrayOf(201U, 202U, 203U, 204U, 205U, 206U, 207U, 208U, 209U, 210U, 211U, 212U, 213U, 214U, 215U, 216U) - val n = ubyteArrayOf(101U, 102U, 103U, 104U, 105U, 106U, 107U, 108U, 109U, 110U, 111U, 112U, 113U, 114U, 115U, 116U) + val k1 = + ubyteArrayOf(201U, 202U, 203U, 204U, 205U, 206U, 207U, 208U, 209U, 210U, 211U, 212U, 213U, 214U, 215U, 216U) + val n = + ubyteArrayOf(101U, 102U, 103U, 104U, 105U, 106U, 107U, 108U, 109U, 110U, 111U, 112U, 113U, 114U, 115U, 116U) assertTrue { val expected = ubyteArrayOf( - 69U, 37U, 68U, 39U, 41U, 15U,107U,193U,255U,139U,122U, 6U,170U,233U,217U, 98U, - 89U,144U,182U,106U, 21U, 51U,200U, 65U,239U, 49U,222U, 34U,215U,114U, 40U,126U, - 104U,197U, 7U,225U,197U,153U, 31U, 2U,102U, 78U, 76U,176U, 84U,245U,246U,184U, - 177U,160U,133U,130U, 6U, 72U,149U,119U,192U,195U,132U,236U,234U,103U,246U, 74U + 69U, 37U, 68U, 39U, 41U, 15U, 107U, 193U, 255U, 139U, 122U, 6U, 170U, 233U, 217U, 98U, + 89U, 144U, 182U, 106U, 21U, 51U, 200U, 65U, 239U, 49U, 222U, 34U, 215U, 114U, 40U, 126U, + 104U, 197U, 7U, 225U, 197U, 153U, 31U, 2U, 102U, 78U, 76U, 176U, 84U, 245U, 246U, 184U, + 177U, 160U, 133U, 130U, 6U, 72U, 149U, 119U, 192U, 195U, 132U, 236U, 234U, 103U, 246U, 74U ) - val result = Salsa20.expansion32(k0+k1, n) + val result = Salsa20.expansion32(k0 + k1, n) result.contentEquals(expected) } assertTrue { val expected = ubyteArrayOf( - 39U,173U, 46U,248U, 30U,200U, 82U, 17U, 48U, 67U,254U,239U, 37U, 18U, 13U,247U, - 241U,200U, 61U,144U, 10U, 55U, 50U,185U, 6U, 47U,246U,253U,143U, 86U,187U,225U, - 134U, 85U,110U,246U,161U,163U, 43U,235U,231U, 94U,171U, 51U,145U,214U,112U, 29U, - 14U,232U, 5U, 16U,151U,140U,183U,141U,171U, 9U,122U,181U,104U,182U,177U,193U + 39U, 173U, 46U, 248U, 30U, 200U, 82U, 17U, 48U, 67U, 254U, 239U, 37U, 18U, 13U, 247U, + 241U, 200U, 61U, 144U, 10U, 55U, 50U, 185U, 6U, 47U, 246U, 253U, 143U, 86U, 187U, 225U, + 134U, 85U, 110U, 246U, 161U, 163U, 43U, 235U, 231U, 94U, 171U, 51U, 145U, 214U, 112U, 29U, + 14U, 232U, 5U, 16U, 151U, 140U, 183U, 141U, 171U, 9U, 122U, 181U, 104U, 182U, 177U, 193U ) val result = Salsa20.expansion16(k0, n) result.contentEquals(expected) } } + + @Test + fun testSalsa20Encryption() { + assertTrue { + val key = "8000000000000000000000000000000000000000000000000000000000000000".hexStringToUByteArray() + val nonce = "0000000000000000".hexStringToUByteArray() + val expectedStartsWith = ( + "E3BE8FDD8BECA2E3EA8EF9475B29A6E7" + + "003951E1097A5C38D23B7A5FAD9F6844" + + "B22C97559E2723C7CBBD3FE4FC8D9A07" + + "44652A83E72A9C461876AF4D7EF1A117").toLowerCase() + val endsWith = ( + "696AFCFD0CDDCC83C7E77F11A649D79A" + + "CDC3354E9635FF137E929933A0BD6F53" + + "77EFA105A3A4266B7C0D089D08F1E855" + + "CC32B15B93784A36E56A76CC64BC8477" + ).toLowerCase() + + val ciphertext = Salsa20.encrypt(key, nonce, UByteArray(512) { 0U }) + ciphertext.toHexString().toLowerCase().startsWith(expectedStartsWith) && + ciphertext.toHexString().toLowerCase().endsWith(endsWith) + } + + assertTrue { + val key = "0A5DB00356A9FC4FA2F5489BEE4194E73A8DE03386D92C7FD22578CB1E71C417".hexStringToUByteArray() + val nonce = "1F86ED54BB2289F0".hexStringToUByteArray() + val expectedStartsWith = ( + "3FE85D5BB1960A82480B5E6F4E965A44" + + "60D7A54501664F7D60B54B06100A37FF" + + "DCF6BDE5CE3F4886BA77DD5B44E95644" + + "E40A8AC65801155DB90F02522B644023").toLowerCase() + val endsWith = ( + "7998204FED70CE8E0D027B206635C08C" + + "8BC443622608970E40E3AEDF3CE790AE" + + "EDF89F922671B45378E2CD03F6F62356" + + "529C4158B7FF41EE854B1235373988C8" + ).toLowerCase() + + val ciphertext = Salsa20.encrypt(key, nonce, UByteArray(131072) { 0U }) + println(ciphertext.slice(0 until 64).toTypedArray().toHexString()) + println(ciphertext.slice(131008 until 131072).toTypedArray().toHexString()) + ciphertext.slice(0 until 64).toTypedArray().toHexString().toLowerCase().startsWith(expectedStartsWith) && + ciphertext.slice(131008 until 131072).toTypedArray().toHexString().toLowerCase().contains(endsWith) + } + + + } + + } \ No newline at end of file