diff --git a/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/symmetric/Aes.kt b/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/symmetric/Aes.kt index 419665d..b8d272f 100644 --- a/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/symmetric/Aes.kt +++ b/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/symmetric/Aes.kt @@ -4,7 +4,7 @@ package com.ionspin.kotlin.crypto.symmetric * Created by Ugljesa Jovanovic (jovanovic.ugljesa@gmail.com) on 07/Sep/2019 */ @ExperimentalUnsignedTypes -class Aes(val aesKey: AesKey) { +class Aes(val aesKey: AesKey, val input: Array) { companion object { val sBox: UByteArray = ubyteArrayOf( @@ -54,20 +54,33 @@ class Aes(val aesKey: AesKey) { } + init { + if (input.size != 16) { + throw RuntimeException("Invalid input size ${input.size}") + } + } - val state: UByteArray = UByteArray(16) { 0U } - - val stateMatrix: Array> = (0 until 4).map { - Array(4) { 0U } + val state: Array> = (0 until 4).map{ outerCounter -> + Array(4) { innerCounter -> input[innerCounter * 4 + outerCounter] } }.toTypedArray() + val numberOfRounds = when(aesKey) { + is AesKey.Aes128Key -> 10 + is AesKey.Aes192Key -> 12 + is AesKey.Aes256Key -> 14 + } + val expandedKey: Array> = expandKey() + + var round = 0 + + fun subBytes() { - stateMatrix.forEachIndexed { indexRow, row -> + state.forEachIndexed { indexRow, row -> row.forEachIndexed { indexColumn, element -> - stateMatrix[indexRow][indexColumn] = getSBoxValue(element) + state[indexRow][indexColumn] = getSBoxValue(element) } } } @@ -79,10 +92,10 @@ class Aes(val aesKey: AesKey) { } fun shiftRows() { - stateMatrix[0] = arrayOf(stateMatrix[0][0], stateMatrix[0][1], stateMatrix[0][2], stateMatrix[0][3]) - stateMatrix[1] = arrayOf(stateMatrix[1][1], stateMatrix[1][2], stateMatrix[1][3], stateMatrix[1][0]) - stateMatrix[2] = arrayOf(stateMatrix[2][2], stateMatrix[2][3], stateMatrix[2][0], stateMatrix[2][1]) - stateMatrix[3] = arrayOf(stateMatrix[3][3], stateMatrix[3][0], stateMatrix[3][1], stateMatrix[3][2]) + state[0] = arrayOf(state[0][0], state[0][1], state[0][2], state[0][3]) + state[1] = arrayOf(state[1][1], state[1][2], state[1][3], state[1][0]) + state[2] = arrayOf(state[2][2], state[2][3], state[2][0], state[2][1]) + state[3] = arrayOf(state[3][3], state[3][0], state[3][1], state[3][2]) } fun mixColumns() { @@ -91,18 +104,18 @@ class Aes(val aesKey: AesKey) { }.toTypedArray() for (c in 0..3) { - stateMixed[0][c] = (2U gfm stateMatrix[0][c]) xor galoisFieldMultiply( + stateMixed[0][c] = (2U gfm state[0][c]) xor galoisFieldMultiply( 3U, - stateMatrix[1][c] - ) xor stateMatrix[2][c] xor stateMatrix[3][c] + state[1][c] + ) xor state[2][c] xor state[3][c] stateMixed[1][c] = - stateMatrix[0][c] xor (2U gfm stateMatrix[1][c]) xor (3U gfm stateMatrix[2][c]) xor stateMatrix[3][c] + state[0][c] xor (2U gfm state[1][c]) xor (3U gfm state[2][c]) xor state[3][c] stateMixed[2][c] = - stateMatrix[0][c] xor stateMatrix[1][c] xor (2U gfm stateMatrix[2][c]) xor (3U gfm stateMatrix[3][c]) + state[0][c] xor state[1][c] xor (2U gfm state[2][c]) xor (3U gfm state[3][c]) stateMixed[3][c] = - 3U gfm stateMatrix[0][c] xor stateMatrix[1][c] xor stateMatrix[2][c] xor (2U gfm stateMatrix[3][c]) + 3U gfm state[0][c] xor state[1][c] xor state[2][c] xor (2U gfm state[3][c]) } - stateMixed.copyInto(stateMatrix) + stateMixed.copyInto(state) } fun galoisFieldAdd(first: UByte, second: UByte): UByte { @@ -131,6 +144,13 @@ class Aes(val aesKey: AesKey) { fun addRoundKey() { + for (i in 0 until 4) { + state[0][i] = state[0][i] xor expandedKey[round * 4 + i][0] + state[1][i] = state[1][i] xor expandedKey[round * 4 + i][1] + state[2][i] = state[2][i] xor expandedKey[round * 4 + i][2] + state[3][i] = state[3][i] xor expandedKey[round * 4 + i][3] + } + round++ } infix fun UInt.gfm(second: UByte): UByte { @@ -138,7 +158,7 @@ class Aes(val aesKey: AesKey) { } fun expandKey(): Array> { - val expandedKey = (0 until 4 * (aesKey.numberOfRounds + 1)).map { + val expandedKey = (0 until 4 * (numberOfRounds + 1)).map { Array(4) { 0U } }.toTypedArray() // First round @@ -149,7 +169,7 @@ class Aes(val aesKey: AesKey) { expandedKey[i][3] = aesKey.keyArray[i * 4 + 3] } - for (i in aesKey.numberOf32BitWords until 4 * (aesKey.numberOfRounds + 1)) { + for (i in aesKey.numberOf32BitWords until 4 * (numberOfRounds + 1)) { val temp = expandedKey[i - 1].copyOf() if (i % aesKey.numberOf32BitWords == 0) { //RotWord @@ -179,15 +199,44 @@ class Aes(val aesKey: AesKey) { } return expandedKey } + + fun encrypt() : Array { + addRoundKey() + + for (i in 0 until numberOfRounds - 1) { + subBytes() + shiftRows() + mixColumns() + addRoundKey() + } + + subBytes() + shiftRows() + addRoundKey() + return state.flatten().toTypedArray() + } + + fun decrypt() : Array { + return ubyteArrayOf().toTypedArray() + } + + fun printState() { + println() + state.forEach { + println(it.joinToString(separator = " ") { it.toString(16) }) + } + } + + } -sealed class AesKey(val key: String, val keyLength: Int, val numberOfRounds: Int) { +sealed class AesKey(val key: String, val keyLength: Int) { val keyArray: Array = key.chunked(2).map { it.toUByte(16) }.toTypedArray() val numberOf32BitWords = keyLength / 32 - class Aes128Key(key: String) : AesKey(key, 128, 10) - class Aes192Key(key: String) : AesKey(key, 192, 12) - class Aes256Key(key: String) : AesKey(key, 256, 14) + class Aes128Key(key: String) : AesKey(key, 128) + class Aes192Key(key: String) : AesKey(key, 192) + class Aes256Key(key: String) : AesKey(key, 256) init { checkKeyLength(key, keyLength) @@ -199,3 +248,5 @@ sealed class AesKey(val key: String, val keyLength: Int, val numberOfRounds: Int } } } + + diff --git a/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/symmetric/Mode.kt b/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/symmetric/Mode.kt new file mode 100644 index 0000000..a0d43f9 --- /dev/null +++ b/multiplatform-crypto/src/commonMain/kotlin/com/ionspin/kotlin/crypto/symmetric/Mode.kt @@ -0,0 +1,27 @@ +/* + * Copyright 2019 Ugljesa Jovanovic + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.ionspin.kotlin.crypto.symmetric + +/** + * Created by Ugljesa Jovanovic + * ugljesa.jovanovic@ionspin.com + * on 18-Sep-2019 + */ + +enum class Mode { + ENCRYPT, DECRYPT +} \ No newline at end of file diff --git a/multiplatform-crypto/src/commonTest/kotlin/com/ionspin/kotlin/crypto/symmetric/AesTest.kt b/multiplatform-crypto/src/commonTest/kotlin/com/ionspin/kotlin/crypto/symmetric/AesTest.kt index e97c21b..bab9e23 100644 --- a/multiplatform-crypto/src/commonTest/kotlin/com/ionspin/kotlin/crypto/symmetric/AesTest.kt +++ b/multiplatform-crypto/src/commonTest/kotlin/com/ionspin/kotlin/crypto/symmetric/AesTest.kt @@ -9,6 +9,9 @@ import kotlin.test.assertTrue @ExperimentalUnsignedTypes class AesTest { + val irrelevantKey = "01234567890123345678901234567890" + val irrelevantInput = UByteArray(16) { 0U }.toTypedArray() + @Test fun testSubBytes() { val fakeState = arrayOf( @@ -17,14 +20,14 @@ class AesTest { ubyteArrayOf(0U, 0U, 0U, 0U).toTypedArray(), ubyteArrayOf(0U, 0U, 0U, 0U).toTypedArray() ) - val aes = Aes(AesKey.Aes128Key("2b7e151628aed2a6abf7158809cf4f3c")) - fakeState.copyInto(aes.stateMatrix) + val aes = Aes(AesKey.Aes128Key(irrelevantKey), irrelevantInput) + fakeState.copyInto(aes.state) aes.subBytes() - aes.stateMatrix.forEach { + aes.state.forEach { println(it.joinToString { it.toString(16) }) } assertTrue { - aes.stateMatrix[0][0] == 0xEDU.toUByte() + aes.state[0][0] == 0xEDU.toUByte() } } @@ -42,14 +45,14 @@ class AesTest { ubyteArrayOf(2U, 3U, 0U, 1U).toTypedArray(), ubyteArrayOf(3U, 0U, 1U, 2U).toTypedArray() ) - val aes = Aes(AesKey.Aes128Key("2b7e151628aed2a6abf7158809cf4f3c")) - fakeState.copyInto(aes.stateMatrix) + val aes = Aes(AesKey.Aes128Key(irrelevantKey), irrelevantInput) + fakeState.copyInto(aes.state) aes.shiftRows() - aes.stateMatrix.forEach { + aes.state.forEach { println(it.joinToString { it.toString(16) }) } assertTrue { - aes.stateMatrix.contentDeepEquals(expectedState) + aes.state.contentDeepEquals(expectedState) } } @@ -59,7 +62,7 @@ class AesTest { assertTrue { val a = 0x57U val b = 0x83U - val aes = Aes(AesKey.Aes128Key("2b7e151628aed2a6abf7158809cf4f3c")) + val aes = Aes(AesKey.Aes128Key(irrelevantKey), irrelevantInput) val c = aes.galoisFieldMultiply(a.toUByte(), b.toUByte()) c == 0xC1U.toUByte() } @@ -67,7 +70,7 @@ class AesTest { assertTrue { val a = 0x57U val b = 0x13U - val aes = Aes(AesKey.Aes128Key("2b7e151628aed2a6abf7158809cf4f3c")) + val aes = Aes(AesKey.Aes128Key(irrelevantKey), irrelevantInput) val c = aes.galoisFieldMultiply(a.toUByte(), b.toUByte()) c == 0xFEU.toUByte() } @@ -92,14 +95,14 @@ class AesTest { ubyteArrayOf(0xbcU, 0x9dU, 0x01U, 0xc6U).toTypedArray() ) - val aes = Aes(AesKey.Aes128Key("2b7e151628aed2a6abf7158809cf4f3c")) - fakeState.copyInto(aes.stateMatrix) + val aes = Aes(AesKey.Aes128Key(irrelevantKey), irrelevantInput) + fakeState.copyInto(aes.state) aes.mixColumns() - aes.stateMatrix.forEach { + aes.state.forEach { println(it.joinToString { it.toString(16) }) } assertTrue { - aes.stateMatrix.contentDeepEquals(expectedState) + aes.state.contentDeepEquals(expectedState) } } @@ -122,7 +125,7 @@ class AesTest { ).toTypedArray() - val aes = Aes(AesKey.Aes128Key(key)) + val aes = Aes(AesKey.Aes128Key(key), irrelevantInput) val result = aes.expandedKey.map { it.foldIndexed(0U) { index, acc, uByte -> acc + (uByte.toUInt() shl (24 - index * 8)) @@ -146,7 +149,7 @@ class AesTest { ).toTypedArray() - val aes = Aes(AesKey.Aes192Key(key)) + val aes = Aes(AesKey.Aes192Key(key), irrelevantInput) val result = aes.expandedKey.map { it.foldIndexed(0U) { index, acc, uByte -> acc + (uByte.toUInt() shl (24 - index * 8)) @@ -172,7 +175,7 @@ class AesTest { ).toTypedArray() - val aes = Aes(AesKey.Aes256Key(key)) + val aes = Aes(AesKey.Aes256Key(key), irrelevantInput) val result = aes.expandedKey.map { it.foldIndexed(0U) { index, acc, uByte -> acc + (uByte.toUInt() shl (24 - index * 8)) @@ -181,6 +184,18 @@ class AesTest { expectedExpandedKey.contentEquals(result) } + } + @Test + fun testEncryption() { + val input = "3243f6a8885a308d313198a2e0370734" + val key = "2b7e151628aed2a6abf7158809cf4f3c" + val expectedResult = "3902dc1925dc116a8409850b1dfb9732" + + val aes = Aes(AesKey.Aes128Key(key), input.chunked(2).map { it.toInt(16).toUByte() }.toTypedArray()) + val result = aes.encrypt() + assertTrue { + result.contentEquals(expectedResult.chunked(2).map { it.toInt(16).toUByte() }.toTypedArray()) + } } } \ No newline at end of file