diff --git a/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/symmetric/Aes.kt b/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/symmetric/Aes.kt index 79b4a48..ecc0a8a 100644 --- a/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/symmetric/Aes.kt +++ b/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/symmetric/Aes.kt @@ -92,6 +92,24 @@ class Aes { stateMatrix[3] = arrayOf(stateMatrix[3][3], stateMatrix[3][0], stateMatrix[3][1], stateMatrix[3][2]) } + fun mixColumns() { + val stateMixed : Array> = (0 until 4).map { + Array(4) { 0U } + }.toTypedArray() + for (c in 0 .. 3) { + + stateMixed[0][c] = (2U gfm stateMatrix[0][c]) xor galoisFieldMultiply(3U, stateMatrix[1][c]) xor stateMatrix[2][c] xor stateMatrix[3][c] + stateMixed[1][c] = stateMatrix[0][c] xor (2U gfm stateMatrix[1][c]) xor (3U gfm stateMatrix[2][c]) xor stateMatrix[3][c] + stateMixed[2][c] = stateMatrix[0][c] xor stateMatrix[1][c] xor (2U gfm stateMatrix[2][c]) xor (3U gfm stateMatrix[3][c]) + stateMixed[3][c] = 3U gfm stateMatrix[0][c] xor stateMatrix[1][c] xor stateMatrix[2][c] xor (2U gfm stateMatrix[3][c]) + } + stateMixed.copyInto(stateMatrix) + } + + fun galoisFieldAdd(first : UByte, second : UByte) : UByte { + return first xor second + } + fun galoisFieldMultiply(first : UByte, second : UByte) : UByte { var result : UInt = 0U var firstInt = first.toUInt() @@ -112,6 +130,10 @@ class Aes { return result.toUByte() } + infix fun UInt.gfm(second : UByte) : UByte { + return galoisFieldMultiply(this.toUByte(), second) + } + fun expandKey(key: AesKey) { } diff --git a/multiplatform-crypto/src/commonTest/kotlin/com/ionspin/kotlin/crypto/symmetric/AesTest.kt b/multiplatform-crypto/src/commonTest/kotlin/com/ionspin/kotlin/crypto/symmetric/AesTest.kt index a6dd137..381c9b6 100644 --- a/multiplatform-crypto/src/commonTest/kotlin/com/ionspin/kotlin/crypto/symmetric/AesTest.kt +++ b/multiplatform-crypto/src/commonTest/kotlin/com/ionspin/kotlin/crypto/symmetric/AesTest.kt @@ -12,15 +12,15 @@ class AesTest { @Test fun testSubBytes() { val fakeState = arrayOf( - ubyteArrayOf(0x53U, 0U, 0U, 0U).toTypedArray(), - ubyteArrayOf(0U, 0U, 0U, 0U).toTypedArray(), - ubyteArrayOf(0U, 0U, 0U, 0U).toTypedArray(), - ubyteArrayOf(0U, 0U, 0U, 0U).toTypedArray() - ) + ubyteArrayOf(0x53U, 0U, 0U, 0U).toTypedArray(), + ubyteArrayOf(0U, 0U, 0U, 0U).toTypedArray(), + ubyteArrayOf(0U, 0U, 0U, 0U).toTypedArray(), + ubyteArrayOf(0U, 0U, 0U, 0U).toTypedArray() + ) val aes = Aes() fakeState.copyInto(aes.stateMatrix) aes.subBytes() - aes.stateMatrix.forEach{ + aes.stateMatrix.forEach { println(it.joinToString { it.toString(16) }) } assertTrue { @@ -31,21 +31,21 @@ class AesTest { @Test fun testShiftRows() { val fakeState = arrayOf( - ubyteArrayOf(0U, 1U, 2U, 3U).toTypedArray(), - ubyteArrayOf(0U, 1U, 2U, 3U).toTypedArray(), - ubyteArrayOf(0U, 1U, 2U, 3U).toTypedArray(), - ubyteArrayOf(0U, 1U, 2U, 3U).toTypedArray() + ubyteArrayOf(0U, 1U, 2U, 3U).toTypedArray(), + ubyteArrayOf(0U, 1U, 2U, 3U).toTypedArray(), + ubyteArrayOf(0U, 1U, 2U, 3U).toTypedArray(), + ubyteArrayOf(0U, 1U, 2U, 3U).toTypedArray() ) val expectedState = arrayOf( - ubyteArrayOf(0U, 1U, 2U, 3U).toTypedArray(), - ubyteArrayOf(1U, 2U, 3U, 0U).toTypedArray(), - ubyteArrayOf(2U, 3U, 0U, 1U).toTypedArray(), - ubyteArrayOf(3U, 0U, 1U, 2U).toTypedArray() + ubyteArrayOf(0U, 1U, 2U, 3U).toTypedArray(), + ubyteArrayOf(1U, 2U, 3U, 0U).toTypedArray(), + ubyteArrayOf(2U, 3U, 0U, 1U).toTypedArray(), + ubyteArrayOf(3U, 0U, 1U, 2U).toTypedArray() ) val aes = Aes() fakeState.copyInto(aes.stateMatrix) aes.shiftRows() - aes.stateMatrix.forEach{ + aes.stateMatrix.forEach { println(it.joinToString { it.toString(16) }) } assertTrue { @@ -55,12 +55,52 @@ class AesTest { @Test fun testGaloisMultiply() { - val a = 0x57U - val b = 0x83U - val aes = Aes() - val c = aes.galoisFieldMultiply(a.toUByte(), b.toUByte()) + //Samples from FIPS-197 assertTrue { + val a = 0x57U + val b = 0x83U + val aes = Aes() + val c = aes.galoisFieldMultiply(a.toUByte(), b.toUByte()) c == 0xC1U.toUByte() } + + assertTrue { + val a = 0x57U + val b = 0x13U + val aes = Aes() + val c = aes.galoisFieldMultiply(a.toUByte(), b.toUByte()) + c == 0xFEU.toUByte() + } + + + } + + @Test + fun testMixColumns() { + //Test vectors from wikipedia + val fakeState = arrayOf( + ubyteArrayOf(0xdbU, 0xf2U, 0x01U, 0xc6U).toTypedArray(), + ubyteArrayOf(0x13U, 0x0aU, 0x01U, 0xc6U).toTypedArray(), + ubyteArrayOf(0x53U, 0x22U, 0x01U, 0xc6U).toTypedArray(), + ubyteArrayOf(0x45U, 0x5cU, 0x01U, 0xc6U).toTypedArray() + ) + + val expectedState = arrayOf( + ubyteArrayOf(0x8eU, 0x9fU, 0x01U, 0xc6U).toTypedArray(), + ubyteArrayOf(0x4dU, 0xdcU, 0x01U, 0xc6U).toTypedArray(), + ubyteArrayOf(0xa1U, 0x58U, 0x01U, 0xc6U).toTypedArray(), + ubyteArrayOf(0xbcU, 0x9dU, 0x01U, 0xc6U).toTypedArray() + ) + + val aes = Aes() + fakeState.copyInto(aes.stateMatrix) + aes.mixColumns() + aes.stateMatrix.forEach { + println(it.joinToString { it.toString(16) }) + } + assertTrue { + aes.stateMatrix.contentDeepEquals(expectedState) + } + } } \ No newline at end of file